support federation queries through http connect proxy (#10475)

Signed-off-by: Marcus Hoffmann <bubu@bubu1.eu>
Signed-off-by: Dirk Klimpel dirk@klimpel.org
pull/10604/head
Dirk Klimpel 2021-08-11 16:34:59 +02:00 committed by GitHub
parent 8c654b7309
commit 339c3918e1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 557 additions and 193 deletions

View File

@ -0,0 +1 @@
Add support for sending federation requests through a proxy. Contributed by @Bubu and @dklimpel.

View File

@ -45,18 +45,18 @@ The proxy will be **used** for:
- recaptcha validation - recaptcha validation
- CAS auth validation - CAS auth validation
- OpenID Connect - OpenID Connect
- Outbound federation
- Federation (checking public key revocation) - Federation (checking public key revocation)
- Fetching public keys of other servers
- Downloading remote media
It will **not be used** for: It will **not be used** for:
- Application Services - Application Services
- Identity servers - Identity servers
- Outbound federation
- In worker configurations - In worker configurations
- connections between workers - connections between workers
- connections from workers to Redis - connections from workers to Redis
- Fetching public keys of other servers
- Downloading remote media
## Troubleshooting ## Troubleshooting

View File

@ -86,6 +86,33 @@ process, for example:
``` ```
# Upgrading to v1.xx.0
## Add support for routing outbound HTTP requests via a proxy for federation
Since Synapse 1.6.0 (2019-11-26) you can set a proxy for outbound HTTP requests via
http_proxy/https_proxy environment variables. This proxy was set for:
- push
- url previews
- phone-home stats
- recaptcha validation
- CAS auth validation
- OpenID Connect
- Federation (checking public key revocation)
In this version we have added support for outbound requests for:
- Outbound federation
- Downloading remote media
- Fetching public keys of other servers
These requests use the same proxy configuration. If you have a proxy configuration we
recommend to verify the configuration. It may be necessary to adjust the `no_proxy`
environment variable.
See [using a forward proxy with Synapse documentation](setup/forward_proxy.md) for
details.
# Upgrading to v1.39.0 # Upgrading to v1.39.0
## Deprecation of the current third-party rules module interface ## Deprecation of the current third-party rules module interface

View File

@ -12,8 +12,11 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import base64
import logging import logging
from typing import Optional
import attr
from zope.interface import implementer from zope.interface import implementer
from twisted.internet import defer, protocol from twisted.internet import defer, protocol
@ -21,7 +24,6 @@ from twisted.internet.error import ConnectError
from twisted.internet.interfaces import IReactorCore, IStreamClientEndpoint from twisted.internet.interfaces import IReactorCore, IStreamClientEndpoint
from twisted.internet.protocol import ClientFactory, Protocol, connectionDone from twisted.internet.protocol import ClientFactory, Protocol, connectionDone
from twisted.web import http from twisted.web import http
from twisted.web.http_headers import Headers
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -30,6 +32,22 @@ class ProxyConnectError(ConnectError):
pass pass
@attr.s
class ProxyCredentials:
username_password = attr.ib(type=bytes)
def as_proxy_authorization_value(self) -> bytes:
"""
Return the value for a Proxy-Authorization header (i.e. 'Basic abdef==').
Returns:
A transformation of the authentication string the encoded value for
a Proxy-Authorization header.
"""
# Encode as base64 and prepend the authorization type
return b"Basic " + base64.encodebytes(self.username_password)
@implementer(IStreamClientEndpoint) @implementer(IStreamClientEndpoint)
class HTTPConnectProxyEndpoint: class HTTPConnectProxyEndpoint:
"""An Endpoint implementation which will send a CONNECT request to an http proxy """An Endpoint implementation which will send a CONNECT request to an http proxy
@ -46,7 +64,7 @@ class HTTPConnectProxyEndpoint:
proxy_endpoint: the endpoint to use to connect to the proxy proxy_endpoint: the endpoint to use to connect to the proxy
host: hostname that we want to CONNECT to host: hostname that we want to CONNECT to
port: port that we want to connect to port: port that we want to connect to
headers: Extra HTTP headers to include in the CONNECT request proxy_creds: credentials to authenticate at proxy
""" """
def __init__( def __init__(
@ -55,20 +73,20 @@ class HTTPConnectProxyEndpoint:
proxy_endpoint: IStreamClientEndpoint, proxy_endpoint: IStreamClientEndpoint,
host: bytes, host: bytes,
port: int, port: int,
headers: Headers, proxy_creds: Optional[ProxyCredentials],
): ):
self._reactor = reactor self._reactor = reactor
self._proxy_endpoint = proxy_endpoint self._proxy_endpoint = proxy_endpoint
self._host = host self._host = host
self._port = port self._port = port
self._headers = headers self._proxy_creds = proxy_creds
def __repr__(self): def __repr__(self):
return "<HTTPConnectProxyEndpoint %s>" % (self._proxy_endpoint,) return "<HTTPConnectProxyEndpoint %s>" % (self._proxy_endpoint,)
def connect(self, protocolFactory: ClientFactory): def connect(self, protocolFactory: ClientFactory):
f = HTTPProxiedClientFactory( f = HTTPProxiedClientFactory(
self._host, self._port, protocolFactory, self._headers self._host, self._port, protocolFactory, self._proxy_creds
) )
d = self._proxy_endpoint.connect(f) d = self._proxy_endpoint.connect(f)
# once the tcp socket connects successfully, we need to wait for the # once the tcp socket connects successfully, we need to wait for the
@ -87,7 +105,7 @@ class HTTPProxiedClientFactory(protocol.ClientFactory):
dst_host: hostname that we want to CONNECT to dst_host: hostname that we want to CONNECT to
dst_port: port that we want to connect to dst_port: port that we want to connect to
wrapped_factory: The original Factory wrapped_factory: The original Factory
headers: Extra HTTP headers to include in the CONNECT request proxy_creds: credentials to authenticate at proxy
""" """
def __init__( def __init__(
@ -95,12 +113,12 @@ class HTTPProxiedClientFactory(protocol.ClientFactory):
dst_host: bytes, dst_host: bytes,
dst_port: int, dst_port: int,
wrapped_factory: ClientFactory, wrapped_factory: ClientFactory,
headers: Headers, proxy_creds: Optional[ProxyCredentials],
): ):
self.dst_host = dst_host self.dst_host = dst_host
self.dst_port = dst_port self.dst_port = dst_port
self.wrapped_factory = wrapped_factory self.wrapped_factory = wrapped_factory
self.headers = headers self.proxy_creds = proxy_creds
self.on_connection = defer.Deferred() self.on_connection = defer.Deferred()
def startedConnecting(self, connector): def startedConnecting(self, connector):
@ -114,7 +132,7 @@ class HTTPProxiedClientFactory(protocol.ClientFactory):
self.dst_port, self.dst_port,
wrapped_protocol, wrapped_protocol,
self.on_connection, self.on_connection,
self.headers, self.proxy_creds,
) )
def clientConnectionFailed(self, connector, reason): def clientConnectionFailed(self, connector, reason):
@ -145,7 +163,7 @@ class HTTPConnectProtocol(protocol.Protocol):
connected_deferred: a Deferred which will be callbacked with connected_deferred: a Deferred which will be callbacked with
wrapped_protocol when the CONNECT completes wrapped_protocol when the CONNECT completes
headers: Extra HTTP headers to include in the CONNECT request proxy_creds: credentials to authenticate at proxy
""" """
def __init__( def __init__(
@ -154,16 +172,16 @@ class HTTPConnectProtocol(protocol.Protocol):
port: int, port: int,
wrapped_protocol: Protocol, wrapped_protocol: Protocol,
connected_deferred: defer.Deferred, connected_deferred: defer.Deferred,
headers: Headers, proxy_creds: Optional[ProxyCredentials],
): ):
self.host = host self.host = host
self.port = port self.port = port
self.wrapped_protocol = wrapped_protocol self.wrapped_protocol = wrapped_protocol
self.connected_deferred = connected_deferred self.connected_deferred = connected_deferred
self.headers = headers self.proxy_creds = proxy_creds
self.http_setup_client = HTTPConnectSetupClient( self.http_setup_client = HTTPConnectSetupClient(
self.host, self.port, self.headers self.host, self.port, self.proxy_creds
) )
self.http_setup_client.on_connected.addCallback(self.proxyConnected) self.http_setup_client.on_connected.addCallback(self.proxyConnected)
@ -205,30 +223,38 @@ class HTTPConnectSetupClient(http.HTTPClient):
Args: Args:
host: The hostname to send in the CONNECT message host: The hostname to send in the CONNECT message
port: The port to send in the CONNECT message port: The port to send in the CONNECT message
headers: Extra headers to send with the CONNECT message proxy_creds: credentials to authenticate at proxy
""" """
def __init__(self, host: bytes, port: int, headers: Headers): def __init__(
self,
host: bytes,
port: int,
proxy_creds: Optional[ProxyCredentials],
):
self.host = host self.host = host
self.port = port self.port = port
self.headers = headers self.proxy_creds = proxy_creds
self.on_connected = defer.Deferred() self.on_connected = defer.Deferred()
def connectionMade(self): def connectionMade(self):
logger.debug("Connected to proxy, sending CONNECT") logger.debug("Connected to proxy, sending CONNECT")
self.sendCommand(b"CONNECT", b"%s:%d" % (self.host, self.port)) self.sendCommand(b"CONNECT", b"%s:%d" % (self.host, self.port))
# Send any additional specified headers # Determine whether we need to set Proxy-Authorization headers
for name, values in self.headers.getAllRawHeaders(): if self.proxy_creds:
for value in values: # Set a Proxy-Authorization header
self.sendHeader(name, value) self.sendHeader(
b"Proxy-Authorization",
self.proxy_creds.as_proxy_authorization_value(),
)
self.endHeaders() self.endHeaders()
def handleStatus(self, version: bytes, status: bytes, message: bytes): def handleStatus(self, version: bytes, status: bytes, message: bytes):
logger.debug("Got Status: %s %s %s", status, message, version) logger.debug("Got Status: %s %s %s", status, message, version)
if status != b"200": if status != b"200":
raise ProxyConnectError("Unexpected status on CONNECT: %s" % status) raise ProxyConnectError(f"Unexpected status on CONNECT: {status!s}")
def handleEndHeaders(self): def handleEndHeaders(self):
logger.debug("End Headers") logger.debug("End Headers")

View File

@ -14,6 +14,10 @@
import logging import logging
import urllib.parse import urllib.parse
from typing import Any, Generator, List, Optional from typing import Any, Generator, List, Optional
from urllib.request import ( # type: ignore[attr-defined]
getproxies_environment,
proxy_bypass_environment,
)
from netaddr import AddrFormatError, IPAddress, IPSet from netaddr import AddrFormatError, IPAddress, IPSet
from zope.interface import implementer from zope.interface import implementer
@ -30,9 +34,12 @@ from twisted.web.http_headers import Headers
from twisted.web.iweb import IAgent, IAgentEndpointFactory, IBodyProducer, IResponse from twisted.web.iweb import IAgent, IAgentEndpointFactory, IBodyProducer, IResponse
from synapse.crypto.context_factory import FederationPolicyForHTTPS from synapse.crypto.context_factory import FederationPolicyForHTTPS
from synapse.http.client import BlacklistingAgentWrapper from synapse.http import proxyagent
from synapse.http.client import BlacklistingAgentWrapper, BlacklistingReactorWrapper
from synapse.http.connectproxyclient import HTTPConnectProxyEndpoint
from synapse.http.federation.srv_resolver import Server, SrvResolver from synapse.http.federation.srv_resolver import Server, SrvResolver
from synapse.http.federation.well_known_resolver import WellKnownResolver from synapse.http.federation.well_known_resolver import WellKnownResolver
from synapse.http.proxyagent import ProxyAgent
from synapse.logging.context import make_deferred_yieldable, run_in_background from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.types import ISynapseReactor from synapse.types import ISynapseReactor
from synapse.util import Clock from synapse.util import Clock
@ -57,6 +64,14 @@ class MatrixFederationAgent:
user_agent: user_agent:
The user agent header to use for federation requests. The user agent header to use for federation requests.
ip_whitelist: Allowed IP addresses.
ip_blacklist: Disallowed IP addresses.
proxy_reactor: twisted reactor to use for connections to the proxy server
reactor might have some blacklisting applied (i.e. for DNS queries),
but we need unblocked access to the proxy.
_srv_resolver: _srv_resolver:
SrvResolver implementation to use for looking up SRV records. None SrvResolver implementation to use for looking up SRV records. None
to use a default implementation. to use a default implementation.
@ -71,11 +86,18 @@ class MatrixFederationAgent:
reactor: ISynapseReactor, reactor: ISynapseReactor,
tls_client_options_factory: Optional[FederationPolicyForHTTPS], tls_client_options_factory: Optional[FederationPolicyForHTTPS],
user_agent: bytes, user_agent: bytes,
ip_whitelist: IPSet,
ip_blacklist: IPSet, ip_blacklist: IPSet,
_srv_resolver: Optional[SrvResolver] = None, _srv_resolver: Optional[SrvResolver] = None,
_well_known_resolver: Optional[WellKnownResolver] = None, _well_known_resolver: Optional[WellKnownResolver] = None,
): ):
self._reactor = reactor # proxy_reactor is not blacklisted
proxy_reactor = reactor
# We need to use a DNS resolver which filters out blacklisted IP
# addresses, to prevent DNS rebinding.
reactor = BlacklistingReactorWrapper(reactor, ip_whitelist, ip_blacklist)
self._clock = Clock(reactor) self._clock = Clock(reactor)
self._pool = HTTPConnectionPool(reactor) self._pool = HTTPConnectionPool(reactor)
self._pool.retryAutomatically = False self._pool.retryAutomatically = False
@ -83,24 +105,27 @@ class MatrixFederationAgent:
self._pool.cachedConnectionTimeout = 2 * 60 self._pool.cachedConnectionTimeout = 2 * 60
self._agent = Agent.usingEndpointFactory( self._agent = Agent.usingEndpointFactory(
self._reactor, reactor,
MatrixHostnameEndpointFactory( MatrixHostnameEndpointFactory(
reactor, tls_client_options_factory, _srv_resolver reactor,
proxy_reactor,
tls_client_options_factory,
_srv_resolver,
), ),
pool=self._pool, pool=self._pool,
) )
self.user_agent = user_agent self.user_agent = user_agent
if _well_known_resolver is None: if _well_known_resolver is None:
# Note that the name resolver has already been wrapped in a
# IPBlacklistingResolver by MatrixFederationHttpClient.
_well_known_resolver = WellKnownResolver( _well_known_resolver = WellKnownResolver(
self._reactor, reactor,
agent=BlacklistingAgentWrapper( agent=BlacklistingAgentWrapper(
Agent( ProxyAgent(
self._reactor, reactor,
proxy_reactor,
pool=self._pool, pool=self._pool,
contextFactory=tls_client_options_factory, contextFactory=tls_client_options_factory,
use_proxy=True,
), ),
ip_blacklist=ip_blacklist, ip_blacklist=ip_blacklist,
), ),
@ -200,10 +225,12 @@ class MatrixHostnameEndpointFactory:
def __init__( def __init__(
self, self,
reactor: IReactorCore, reactor: IReactorCore,
proxy_reactor: IReactorCore,
tls_client_options_factory: Optional[FederationPolicyForHTTPS], tls_client_options_factory: Optional[FederationPolicyForHTTPS],
srv_resolver: Optional[SrvResolver], srv_resolver: Optional[SrvResolver],
): ):
self._reactor = reactor self._reactor = reactor
self._proxy_reactor = proxy_reactor
self._tls_client_options_factory = tls_client_options_factory self._tls_client_options_factory = tls_client_options_factory
if srv_resolver is None: if srv_resolver is None:
@ -211,9 +238,10 @@ class MatrixHostnameEndpointFactory:
self._srv_resolver = srv_resolver self._srv_resolver = srv_resolver
def endpointForURI(self, parsed_uri): def endpointForURI(self, parsed_uri: URI):
return MatrixHostnameEndpoint( return MatrixHostnameEndpoint(
self._reactor, self._reactor,
self._proxy_reactor,
self._tls_client_options_factory, self._tls_client_options_factory,
self._srv_resolver, self._srv_resolver,
parsed_uri, parsed_uri,
@ -227,23 +255,45 @@ class MatrixHostnameEndpoint:
Args: Args:
reactor: twisted reactor to use for underlying requests reactor: twisted reactor to use for underlying requests
proxy_reactor: twisted reactor to use for connections to the proxy server.
'reactor' might have some blacklisting applied (i.e. for DNS queries),
but we need unblocked access to the proxy.
tls_client_options_factory: tls_client_options_factory:
factory to use for fetching client tls options, or none to disable TLS. factory to use for fetching client tls options, or none to disable TLS.
srv_resolver: The SRV resolver to use srv_resolver: The SRV resolver to use
parsed_uri: The parsed URI that we're wanting to connect to. parsed_uri: The parsed URI that we're wanting to connect to.
Raises:
ValueError if the environment variables contain an invalid proxy specification.
RuntimeError if no tls_options_factory is given for a https connection
""" """
def __init__( def __init__(
self, self,
reactor: IReactorCore, reactor: IReactorCore,
proxy_reactor: IReactorCore,
tls_client_options_factory: Optional[FederationPolicyForHTTPS], tls_client_options_factory: Optional[FederationPolicyForHTTPS],
srv_resolver: SrvResolver, srv_resolver: SrvResolver,
parsed_uri: URI, parsed_uri: URI,
): ):
self._reactor = reactor self._reactor = reactor
self._parsed_uri = parsed_uri self._parsed_uri = parsed_uri
# http_proxy is not needed because federation is always over TLS
proxies = getproxies_environment()
https_proxy = proxies["https"].encode() if "https" in proxies else None
self.no_proxy = proxies["no"] if "no" in proxies else None
# endpoint and credentials to use to connect to the outbound https proxy, if any.
(
self._https_proxy_endpoint,
self._https_proxy_creds,
) = proxyagent.http_proxy_endpoint(
https_proxy,
proxy_reactor,
tls_client_options_factory,
)
# set up the TLS connection params # set up the TLS connection params
# #
# XXX disabling TLS is really only supported here for the benefit of the # XXX disabling TLS is really only supported here for the benefit of the
@ -273,9 +323,33 @@ class MatrixHostnameEndpoint:
host = server.host host = server.host
port = server.port port = server.port
should_skip_proxy = False
if self.no_proxy is not None:
should_skip_proxy = proxy_bypass_environment(
host.decode(),
proxies={"no": self.no_proxy},
)
endpoint: IStreamClientEndpoint
try: try:
logger.debug("Connecting to %s:%i", host.decode("ascii"), port) if self._https_proxy_endpoint and not should_skip_proxy:
endpoint = HostnameEndpoint(self._reactor, host, port) logger.debug(
"Connecting to %s:%i via %s",
host.decode("ascii"),
port,
self._https_proxy_endpoint,
)
endpoint = HTTPConnectProxyEndpoint(
self._reactor,
self._https_proxy_endpoint,
host,
port,
proxy_creds=self._https_proxy_creds,
)
else:
logger.debug("Connecting to %s:%i", host.decode("ascii"), port)
# not using a proxy
endpoint = HostnameEndpoint(self._reactor, host, port)
if self._tls_options: if self._tls_options:
endpoint = wrapClientTLS(self._tls_options, endpoint) endpoint = wrapClientTLS(self._tls_options, endpoint)
result = await make_deferred_yieldable( result = await make_deferred_yieldable(

View File

@ -59,7 +59,6 @@ from synapse.api.errors import (
from synapse.http import QuieterFileBodyProducer from synapse.http import QuieterFileBodyProducer
from synapse.http.client import ( from synapse.http.client import (
BlacklistingAgentWrapper, BlacklistingAgentWrapper,
BlacklistingReactorWrapper,
BodyExceededMaxSize, BodyExceededMaxSize,
ByteWriteable, ByteWriteable,
encode_query_args, encode_query_args,
@ -69,7 +68,7 @@ from synapse.http.federation.matrix_federation_agent import MatrixFederationAgen
from synapse.logging import opentracing from synapse.logging import opentracing
from synapse.logging.context import make_deferred_yieldable from synapse.logging.context import make_deferred_yieldable
from synapse.logging.opentracing import set_tag, start_active_span, tags from synapse.logging.opentracing import set_tag, start_active_span, tags
from synapse.types import ISynapseReactor, JsonDict from synapse.types import JsonDict
from synapse.util import json_decoder from synapse.util import json_decoder
from synapse.util.async_helpers import timeout_deferred from synapse.util.async_helpers import timeout_deferred
from synapse.util.metrics import Measure from synapse.util.metrics import Measure
@ -325,13 +324,7 @@ class MatrixFederationHttpClient:
self.signing_key = hs.signing_key self.signing_key = hs.signing_key
self.server_name = hs.hostname self.server_name = hs.hostname
# We need to use a DNS resolver which filters out blacklisted IP self.reactor = hs.get_reactor()
# addresses, to prevent DNS rebinding.
self.reactor: ISynapseReactor = BlacklistingReactorWrapper(
hs.get_reactor(),
hs.config.federation_ip_range_whitelist,
hs.config.federation_ip_range_blacklist,
)
user_agent = hs.version_string user_agent = hs.version_string
if hs.config.user_agent_suffix: if hs.config.user_agent_suffix:
@ -342,6 +335,7 @@ class MatrixFederationHttpClient:
self.reactor, self.reactor,
tls_client_options_factory, tls_client_options_factory,
user_agent, user_agent,
hs.config.federation_ip_range_whitelist,
hs.config.federation_ip_range_blacklist, hs.config.federation_ip_range_blacklist,
) )

View File

@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import base64
import logging import logging
import re import re
from typing import Any, Dict, Optional, Tuple from typing import Any, Dict, Optional, Tuple
@ -21,7 +20,6 @@ from urllib.request import ( # type: ignore[attr-defined]
proxy_bypass_environment, proxy_bypass_environment,
) )
import attr
from zope.interface import implementer from zope.interface import implementer
from twisted.internet import defer from twisted.internet import defer
@ -38,7 +36,7 @@ from twisted.web.error import SchemeNotSupported
from twisted.web.http_headers import Headers from twisted.web.http_headers import Headers
from twisted.web.iweb import IAgent, IBodyProducer, IPolicyForHTTPS from twisted.web.iweb import IAgent, IBodyProducer, IPolicyForHTTPS
from synapse.http.connectproxyclient import HTTPConnectProxyEndpoint from synapse.http.connectproxyclient import HTTPConnectProxyEndpoint, ProxyCredentials
from synapse.types import ISynapseReactor from synapse.types import ISynapseReactor
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -46,22 +44,6 @@ logger = logging.getLogger(__name__)
_VALID_URI = re.compile(br"\A[\x21-\x7e]+\Z") _VALID_URI = re.compile(br"\A[\x21-\x7e]+\Z")
@attr.s
class ProxyCredentials:
username_password = attr.ib(type=bytes)
def as_proxy_authorization_value(self) -> bytes:
"""
Return the value for a Proxy-Authorization header (i.e. 'Basic abdef==').
Returns:
A transformation of the authentication string the encoded value for
a Proxy-Authorization header.
"""
# Encode as base64 and prepend the authorization type
return b"Basic " + base64.encodebytes(self.username_password)
@implementer(IAgent) @implementer(IAgent)
class ProxyAgent(_AgentBase): class ProxyAgent(_AgentBase):
"""An Agent implementation which will use an HTTP proxy if one was requested """An Agent implementation which will use an HTTP proxy if one was requested
@ -95,6 +77,7 @@ class ProxyAgent(_AgentBase):
Raises: Raises:
ValueError if use_proxy is set and the environment variables ValueError if use_proxy is set and the environment variables
contain an invalid proxy specification. contain an invalid proxy specification.
RuntimeError if no tls_options_factory is given for a https connection
""" """
def __init__( def __init__(
@ -131,11 +114,11 @@ class ProxyAgent(_AgentBase):
https_proxy = proxies["https"].encode() if "https" in proxies else None https_proxy = proxies["https"].encode() if "https" in proxies else None
no_proxy = proxies["no"] if "no" in proxies else None no_proxy = proxies["no"] if "no" in proxies else None
self.http_proxy_endpoint, self.http_proxy_creds = _http_proxy_endpoint( self.http_proxy_endpoint, self.http_proxy_creds = http_proxy_endpoint(
http_proxy, self.proxy_reactor, contextFactory, **self._endpoint_kwargs http_proxy, self.proxy_reactor, contextFactory, **self._endpoint_kwargs
) )
self.https_proxy_endpoint, self.https_proxy_creds = _http_proxy_endpoint( self.https_proxy_endpoint, self.https_proxy_creds = http_proxy_endpoint(
https_proxy, self.proxy_reactor, contextFactory, **self._endpoint_kwargs https_proxy, self.proxy_reactor, contextFactory, **self._endpoint_kwargs
) )
@ -224,22 +207,12 @@ class ProxyAgent(_AgentBase):
and self.https_proxy_endpoint and self.https_proxy_endpoint
and not should_skip_proxy and not should_skip_proxy
): ):
connect_headers = Headers()
# Determine whether we need to set Proxy-Authorization headers
if self.https_proxy_creds:
# Set a Proxy-Authorization header
connect_headers.addRawHeader(
b"Proxy-Authorization",
self.https_proxy_creds.as_proxy_authorization_value(),
)
endpoint = HTTPConnectProxyEndpoint( endpoint = HTTPConnectProxyEndpoint(
self.proxy_reactor, self.proxy_reactor,
self.https_proxy_endpoint, self.https_proxy_endpoint,
parsed_uri.host, parsed_uri.host,
parsed_uri.port, parsed_uri.port,
headers=connect_headers, self.https_proxy_creds,
) )
else: else:
# not using a proxy # not using a proxy
@ -268,10 +241,10 @@ class ProxyAgent(_AgentBase):
) )
def _http_proxy_endpoint( def http_proxy_endpoint(
proxy: Optional[bytes], proxy: Optional[bytes],
reactor: IReactorCore, reactor: IReactorCore,
tls_options_factory: IPolicyForHTTPS, tls_options_factory: Optional[IPolicyForHTTPS],
**kwargs, **kwargs,
) -> Tuple[Optional[IStreamClientEndpoint], Optional[ProxyCredentials]]: ) -> Tuple[Optional[IStreamClientEndpoint], Optional[ProxyCredentials]]:
"""Parses an http proxy setting and returns an endpoint for the proxy """Parses an http proxy setting and returns an endpoint for the proxy
@ -294,6 +267,7 @@ def _http_proxy_endpoint(
Raise: Raise:
ValueError if proxy has no hostname or unsupported scheme. ValueError if proxy has no hostname or unsupported scheme.
RuntimeError if no tls_options_factory is given for a https connection
""" """
if proxy is None: if proxy is None:
return None, None return None, None
@ -305,8 +279,13 @@ def _http_proxy_endpoint(
proxy_endpoint = HostnameEndpoint(reactor, host, port, **kwargs) proxy_endpoint = HostnameEndpoint(reactor, host, port, **kwargs)
if scheme == b"https": if scheme == b"https":
tls_options = tls_options_factory.creatorForNetloc(host, port) if tls_options_factory:
proxy_endpoint = wrapClientTLS(tls_options, proxy_endpoint) tls_options = tls_options_factory.creatorForNetloc(host, port)
proxy_endpoint = wrapClientTLS(tls_options, proxy_endpoint)
else:
raise RuntimeError(
f"No TLS options for a https connection via proxy {proxy!s}"
)
return proxy_endpoint, credentials return proxy_endpoint, credentials

View File

@ -11,9 +11,11 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import base64
import logging import logging
from typing import Optional import os
from unittest.mock import Mock from typing import Iterable, Optional
from unittest.mock import Mock, patch
import treq import treq
from netaddr import IPSet from netaddr import IPSet
@ -22,11 +24,12 @@ from zope.interface import implementer
from twisted.internet import defer from twisted.internet import defer
from twisted.internet._sslverify import ClientTLSOptions, OpenSSLCertificateOptions from twisted.internet._sslverify import ClientTLSOptions, OpenSSLCertificateOptions
from twisted.internet.interfaces import IProtocolFactory
from twisted.internet.protocol import Factory from twisted.internet.protocol import Factory
from twisted.protocols.tls import TLSMemoryBIOFactory from twisted.protocols.tls import TLSMemoryBIOFactory, TLSMemoryBIOProtocol
from twisted.web._newclient import ResponseNeverReceived from twisted.web._newclient import ResponseNeverReceived
from twisted.web.client import Agent from twisted.web.client import Agent
from twisted.web.http import HTTPChannel from twisted.web.http import HTTPChannel, Request
from twisted.web.http_headers import Headers from twisted.web.http_headers import Headers
from twisted.web.iweb import IPolicyForHTTPS from twisted.web.iweb import IPolicyForHTTPS
@ -49,24 +52,6 @@ from tests.utils import default_config
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
test_server_connection_factory = None
def get_connection_factory():
# this needs to happen once, but not until we are ready to run the first test
global test_server_connection_factory
if test_server_connection_factory is None:
test_server_connection_factory = TestServerTLSConnectionFactory(
sanlist=[
b"DNS:testserv",
b"DNS:target-server",
b"DNS:xn--bcher-kva.com",
b"IP:1.2.3.4",
b"IP:::1",
]
)
return test_server_connection_factory
# Once Async Mocks or lambdas are supported this can go away. # Once Async Mocks or lambdas are supported this can go away.
def generate_resolve_service(result): def generate_resolve_service(result):
@ -100,24 +85,38 @@ class MatrixFederationAgentTests(unittest.TestCase):
had_well_known_cache=self.had_well_known_cache, had_well_known_cache=self.had_well_known_cache,
) )
self.agent = MatrixFederationAgent( def _make_connection(
reactor=self.reactor, self,
tls_client_options_factory=self.tls_factory, client_factory: IProtocolFactory,
user_agent="test-agent", # Note that this is unused since _well_known_resolver is provided. ssl: bool = True,
ip_blacklist=IPSet(), expected_sni: bytes = None,
_srv_resolver=self.mock_resolver, tls_sanlist: Optional[Iterable[bytes]] = None,
_well_known_resolver=self.well_known_resolver, ) -> HTTPChannel:
)
def _make_connection(self, client_factory, expected_sni):
"""Builds a test server, and completes the outgoing client connection """Builds a test server, and completes the outgoing client connection
Args:
client_factory: the the factory that the
application is trying to use to make the outbound connection. We will
invoke it to build the client Protocol
ssl: If true, we will expect an ssl connection and wrap
server_factory with a TLSMemoryBIOFactory
False is set only for when proxy expect http connection.
Otherwise federation requests use always https.
expected_sni: the expected SNI value
tls_sanlist: list of SAN entries for the TLS cert presented by the server.
Returns: Returns:
HTTPChannel: the test server the server Protocol returned by server_factory
""" """
# build the test server # build the test server
server_tls_protocol = _build_test_server(get_connection_factory()) server_factory = _get_test_protocol_factory()
if ssl:
server_factory = _wrap_server_factory_for_tls(server_factory, tls_sanlist)
server_protocol = server_factory.buildProtocol(None)
# now, tell the client protocol factory to build the client protocol (it will be a # now, tell the client protocol factory to build the client protocol (it will be a
# _WrappingProtocol, around a TLSMemoryBIOProtocol, around an # _WrappingProtocol, around a TLSMemoryBIOProtocol, around an
@ -128,35 +127,39 @@ class MatrixFederationAgentTests(unittest.TestCase):
# stubbing that out here. # stubbing that out here.
client_protocol = client_factory.buildProtocol(None) client_protocol = client_factory.buildProtocol(None)
client_protocol.makeConnection( client_protocol.makeConnection(
FakeTransport(server_tls_protocol, self.reactor, client_protocol) FakeTransport(server_protocol, self.reactor, client_protocol)
) )
# tell the server tls protocol to send its stuff back to the client, too # tell the server protocol to send its stuff back to the client, too
server_tls_protocol.makeConnection( server_protocol.makeConnection(
FakeTransport(client_protocol, self.reactor, server_tls_protocol) FakeTransport(client_protocol, self.reactor, server_protocol)
) )
# grab a hold of the TLS connection, in case it gets torn down if ssl:
server_tls_connection = server_tls_protocol._tlsConnection # fish the test server back out of the server-side TLS protocol.
http_protocol = server_protocol.wrappedProtocol
# grab a hold of the TLS connection, in case it gets torn down
tls_connection = server_protocol._tlsConnection
else:
http_protocol = server_protocol
tls_connection = None
# fish the test server back out of the server-side TLS protocol. # give the reactor a pump to get the TLS juices flowing (if needed)
http_protocol = server_tls_protocol.wrappedProtocol self.reactor.advance(0)
# give the reactor a pump to get the TLS juices flowing.
self.reactor.pump((0.1,))
# check the SNI # check the SNI
server_name = server_tls_connection.get_servername() if expected_sni is not None:
self.assertEqual( server_name = tls_connection.get_servername()
server_name, self.assertEqual(
expected_sni, server_name,
"Expected SNI %s but got %s" % (expected_sni, server_name), expected_sni,
) f"Expected SNI {expected_sni!s} but got {server_name!s}",
)
return http_protocol return http_protocol
@defer.inlineCallbacks @defer.inlineCallbacks
def _make_get_request(self, uri): def _make_get_request(self, uri: bytes):
""" """
Sends a simple GET request via the agent, and checks its logcontext management Sends a simple GET request via the agent, and checks its logcontext management
""" """
@ -180,20 +183,20 @@ class MatrixFederationAgentTests(unittest.TestCase):
def _handle_well_known_connection( def _handle_well_known_connection(
self, self,
client_factory, client_factory: IProtocolFactory,
expected_sni, expected_sni: bytes,
content, content: bytes,
response_headers: Optional[dict] = None, response_headers: Optional[dict] = None,
): ) -> HTTPChannel:
"""Handle an outgoing HTTPs connection: wire it up to a server, check that the """Handle an outgoing HTTPs connection: wire it up to a server, check that the
request is for a .well-known, and send the response. request is for a .well-known, and send the response.
Args: Args:
client_factory (IProtocolFactory): outgoing connection client_factory: outgoing connection
expected_sni (bytes): SNI that we expect the outgoing connection to send expected_sni: SNI that we expect the outgoing connection to send
content (bytes): content to send back as the .well-known content: content to send back as the .well-known
Returns: Returns:
HTTPChannel: server impl server impl
""" """
# make the connection for .well-known # make the connection for .well-known
well_known_server = self._make_connection( well_known_server = self._make_connection(
@ -209,7 +212,10 @@ class MatrixFederationAgentTests(unittest.TestCase):
return well_known_server return well_known_server
def _send_well_known_response( def _send_well_known_response(
self, request, content, headers: Optional[dict] = None self,
request: Request,
content: bytes,
headers: Optional[dict] = None,
): ):
"""Check that an incoming request looks like a valid .well-known request, and """Check that an incoming request looks like a valid .well-known request, and
send back the response. send back the response.
@ -225,10 +231,37 @@ class MatrixFederationAgentTests(unittest.TestCase):
self.reactor.pump((0.1,)) self.reactor.pump((0.1,))
def _make_agent(self) -> MatrixFederationAgent:
"""
If a proxy server is set, the MatrixFederationAgent must be created again
because it is created too early during setUp
"""
return MatrixFederationAgent(
reactor=self.reactor,
tls_client_options_factory=self.tls_factory,
user_agent="test-agent", # Note that this is unused since _well_known_resolver is provided.
ip_whitelist=IPSet(),
ip_blacklist=IPSet(),
_srv_resolver=self.mock_resolver,
_well_known_resolver=self.well_known_resolver,
)
def test_get(self): def test_get(self):
""" """happy-path test of a GET request with an explicit port"""
happy-path test of a GET request with an explicit port self._do_get()
"""
@patch.dict(
os.environ,
{"https_proxy": "proxy.com", "no_proxy": "testserv"},
)
def test_get_bypass_proxy(self):
"""test of a GET request with an explicit port and bypass proxy"""
self._do_get()
def _do_get(self):
"""test of a GET request with an explicit port"""
self.agent = self._make_agent()
self.reactor.lookups["testserv"] = "1.2.3.4" self.reactor.lookups["testserv"] = "1.2.3.4"
test_d = self._make_get_request(b"matrix://testserv:8448/foo/bar") test_d = self._make_get_request(b"matrix://testserv:8448/foo/bar")
@ -282,10 +315,188 @@ class MatrixFederationAgentTests(unittest.TestCase):
json = self.successResultOf(treq.json_content(response)) json = self.successResultOf(treq.json_content(response))
self.assertEqual(json, {"a": 1}) self.assertEqual(json, {"a": 1})
@patch.dict(
os.environ, {"https_proxy": "http://proxy.com", "no_proxy": "unused.com"}
)
def test_get_via_http_proxy(self):
"""test for federation request through a http proxy"""
self._do_get_via_proxy(expect_proxy_ssl=False, expected_auth_credentials=None)
@patch.dict(
os.environ,
{"https_proxy": "http://user:pass@proxy.com", "no_proxy": "unused.com"},
)
def test_get_via_http_proxy_with_auth(self):
"""test for federation request through a http proxy with authentication"""
self._do_get_via_proxy(
expect_proxy_ssl=False, expected_auth_credentials=b"user:pass"
)
@patch.dict(
os.environ, {"https_proxy": "https://proxy.com", "no_proxy": "unused.com"}
)
def test_get_via_https_proxy(self):
"""test for federation request through a https proxy"""
self._do_get_via_proxy(expect_proxy_ssl=True, expected_auth_credentials=None)
@patch.dict(
os.environ,
{"https_proxy": "https://user:pass@proxy.com", "no_proxy": "unused.com"},
)
def test_get_via_https_proxy_with_auth(self):
"""test for federation request through a https proxy with authentication"""
self._do_get_via_proxy(
expect_proxy_ssl=True, expected_auth_credentials=b"user:pass"
)
def _do_get_via_proxy(
self,
expect_proxy_ssl: bool = False,
expected_auth_credentials: Optional[bytes] = None,
):
"""Send a https federation request via an agent and check that it is correctly
received at the proxy and client. The proxy can use either http or https.
Args:
expect_proxy_ssl: True if we expect the request to connect to the proxy via https.
expected_auth_credentials: credentials we expect to be presented to authenticate at the proxy
"""
self.agent = self._make_agent()
self.reactor.lookups["testserv"] = "1.2.3.4"
self.reactor.lookups["proxy.com"] = "9.9.9.9"
test_d = self._make_get_request(b"matrix://testserv:8448/foo/bar")
# Nothing happened yet
self.assertNoResult(test_d)
# Make sure treq is trying to connect
clients = self.reactor.tcpClients
self.assertEqual(len(clients), 1)
(host, port, client_factory, _timeout, _bindAddress) = clients[0]
# make sure we are connecting to the proxy
self.assertEqual(host, "9.9.9.9")
self.assertEqual(port, 1080)
# make a test server to act as the proxy, and wire up the client
proxy_server = self._make_connection(
client_factory,
ssl=expect_proxy_ssl,
tls_sanlist=[b"DNS:proxy.com"] if expect_proxy_ssl else None,
expected_sni=b"proxy.com" if expect_proxy_ssl else None,
)
assert isinstance(proxy_server, HTTPChannel)
# now there should be a pending CONNECT request
self.assertEqual(len(proxy_server.requests), 1)
request = proxy_server.requests[0]
self.assertEqual(request.method, b"CONNECT")
self.assertEqual(request.path, b"testserv:8448")
# Check whether auth credentials have been supplied to the proxy
proxy_auth_header_values = request.requestHeaders.getRawHeaders(
b"Proxy-Authorization"
)
if expected_auth_credentials is not None:
# Compute the correct header value for Proxy-Authorization
encoded_credentials = base64.b64encode(expected_auth_credentials)
expected_header_value = b"Basic " + encoded_credentials
# Validate the header's value
self.assertIn(expected_header_value, proxy_auth_header_values)
else:
# Check that the Proxy-Authorization header has not been supplied to the proxy
self.assertIsNone(proxy_auth_header_values)
# tell the proxy server not to close the connection
proxy_server.persistent = True
request.finish()
# now we make another test server to act as the upstream HTTP server.
server_ssl_protocol = _wrap_server_factory_for_tls(
_get_test_protocol_factory()
).buildProtocol(None)
# Tell the HTTP server to send outgoing traffic back via the proxy's transport.
proxy_server_transport = proxy_server.transport
server_ssl_protocol.makeConnection(proxy_server_transport)
# ... and replace the protocol on the proxy's transport with the
# TLSMemoryBIOProtocol for the test server, so that incoming traffic
# to the proxy gets sent over to the HTTP(s) server.
# See also comment at `_do_https_request_via_proxy`
# in ../test_proxyagent.py for more details
if expect_proxy_ssl:
assert isinstance(proxy_server_transport, TLSMemoryBIOProtocol)
proxy_server_transport.wrappedProtocol = server_ssl_protocol
else:
assert isinstance(proxy_server_transport, FakeTransport)
client_protocol = proxy_server_transport.other
c2s_transport = client_protocol.transport
c2s_transport.other = server_ssl_protocol
self.reactor.advance(0)
server_name = server_ssl_protocol._tlsConnection.get_servername()
expected_sni = b"testserv"
self.assertEqual(
server_name,
expected_sni,
f"Expected SNI {expected_sni!s} but got {server_name!s}",
)
# now there should be a pending request
http_server = server_ssl_protocol.wrappedProtocol
self.assertEqual(len(http_server.requests), 1)
request = http_server.requests[0]
self.assertEqual(request.method, b"GET")
self.assertEqual(request.path, b"/foo/bar")
self.assertEqual(
request.requestHeaders.getRawHeaders(b"host"), [b"testserv:8448"]
)
self.assertEqual(
request.requestHeaders.getRawHeaders(b"user-agent"), [b"test-agent"]
)
# Check that the destination server DID NOT receive proxy credentials
self.assertIsNone(request.requestHeaders.getRawHeaders(b"Proxy-Authorization"))
content = request.content.read()
self.assertEqual(content, b"")
# Deferred is still without a result
self.assertNoResult(test_d)
# send the headers
request.responseHeaders.setRawHeaders(b"Content-Type", [b"application/json"])
request.write("")
self.reactor.pump((0.1,))
response = self.successResultOf(test_d)
# that should give us a Response object
self.assertEqual(response.code, 200)
# Send the body
request.write('{ "a": 1 }'.encode("ascii"))
request.finish()
self.reactor.pump((0.1,))
# check it can be read
json = self.successResultOf(treq.json_content(response))
self.assertEqual(json, {"a": 1})
def test_get_ip_address(self): def test_get_ip_address(self):
""" """
Test the behaviour when the server name contains an explicit IP (with no port) Test the behaviour when the server name contains an explicit IP (with no port)
""" """
self.agent = self._make_agent()
# there will be a getaddrinfo on the IP # there will be a getaddrinfo on the IP
self.reactor.lookups["1.2.3.4"] = "1.2.3.4" self.reactor.lookups["1.2.3.4"] = "1.2.3.4"
@ -320,6 +531,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
Test the behaviour when the server name contains an explicit IPv6 address Test the behaviour when the server name contains an explicit IPv6 address
(with no port) (with no port)
""" """
self.agent = self._make_agent()
# there will be a getaddrinfo on the IP # there will be a getaddrinfo on the IP
self.reactor.lookups["::1"] = "::1" self.reactor.lookups["::1"] = "::1"
@ -355,6 +567,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
Test the behaviour when the server name contains an explicit IPv6 address Test the behaviour when the server name contains an explicit IPv6 address
(with explicit port) (with explicit port)
""" """
self.agent = self._make_agent()
# there will be a getaddrinfo on the IP # there will be a getaddrinfo on the IP
self.reactor.lookups["::1"] = "::1" self.reactor.lookups["::1"] = "::1"
@ -389,6 +602,8 @@ class MatrixFederationAgentTests(unittest.TestCase):
""" """
Test the behaviour when the certificate on the server doesn't match the hostname Test the behaviour when the certificate on the server doesn't match the hostname
""" """
self.agent = self._make_agent()
self.mock_resolver.resolve_service.side_effect = generate_resolve_service([]) self.mock_resolver.resolve_service.side_effect = generate_resolve_service([])
self.reactor.lookups["testserv1"] = "1.2.3.4" self.reactor.lookups["testserv1"] = "1.2.3.4"
@ -441,6 +656,8 @@ class MatrixFederationAgentTests(unittest.TestCase):
Test the behaviour when the server name contains an explicit IP, but Test the behaviour when the server name contains an explicit IP, but
the server cert doesn't cover it the server cert doesn't cover it
""" """
self.agent = self._make_agent()
# there will be a getaddrinfo on the IP # there will be a getaddrinfo on the IP
self.reactor.lookups["1.2.3.5"] = "1.2.3.5" self.reactor.lookups["1.2.3.5"] = "1.2.3.5"
@ -471,6 +688,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
""" """
Test the behaviour when the server name has no port, no SRV, and no well-known Test the behaviour when the server name has no port, no SRV, and no well-known
""" """
self.agent = self._make_agent()
self.mock_resolver.resolve_service.side_effect = generate_resolve_service([]) self.mock_resolver.resolve_service.side_effect = generate_resolve_service([])
self.reactor.lookups["testserv"] = "1.2.3.4" self.reactor.lookups["testserv"] = "1.2.3.4"
@ -524,6 +742,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
def test_get_well_known(self): def test_get_well_known(self):
"""Test the behaviour when the .well-known delegates elsewhere""" """Test the behaviour when the .well-known delegates elsewhere"""
self.agent = self._make_agent()
self.mock_resolver.resolve_service.side_effect = generate_resolve_service([]) self.mock_resolver.resolve_service.side_effect = generate_resolve_service([])
self.reactor.lookups["testserv"] = "1.2.3.4" self.reactor.lookups["testserv"] = "1.2.3.4"
@ -587,6 +806,8 @@ class MatrixFederationAgentTests(unittest.TestCase):
"""Test the behaviour when the server name has no port and no SRV record, but """Test the behaviour when the server name has no port and no SRV record, but
the .well-known has a 300 redirect the .well-known has a 300 redirect
""" """
self.agent = self._make_agent()
self.mock_resolver.resolve_service.side_effect = generate_resolve_service([]) self.mock_resolver.resolve_service.side_effect = generate_resolve_service([])
self.reactor.lookups["testserv"] = "1.2.3.4" self.reactor.lookups["testserv"] = "1.2.3.4"
self.reactor.lookups["target-server"] = "1::f" self.reactor.lookups["target-server"] = "1::f"
@ -675,6 +896,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
""" """
Test the behaviour when the server name has an *invalid* well-known (and no SRV) Test the behaviour when the server name has an *invalid* well-known (and no SRV)
""" """
self.agent = self._make_agent()
self.mock_resolver.resolve_service.side_effect = generate_resolve_service([]) self.mock_resolver.resolve_service.side_effect = generate_resolve_service([])
self.reactor.lookups["testserv"] = "1.2.3.4" self.reactor.lookups["testserv"] = "1.2.3.4"
@ -743,6 +965,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
reactor=self.reactor, reactor=self.reactor,
tls_client_options_factory=tls_factory, tls_client_options_factory=tls_factory,
user_agent=b"test-agent", # This is unused since _well_known_resolver is passed below. user_agent=b"test-agent", # This is unused since _well_known_resolver is passed below.
ip_whitelist=IPSet(),
ip_blacklist=IPSet(), ip_blacklist=IPSet(),
_srv_resolver=self.mock_resolver, _srv_resolver=self.mock_resolver,
_well_known_resolver=WellKnownResolver( _well_known_resolver=WellKnownResolver(
@ -780,6 +1003,8 @@ class MatrixFederationAgentTests(unittest.TestCase):
""" """
Test the behaviour when there is a single SRV record Test the behaviour when there is a single SRV record
""" """
self.agent = self._make_agent()
self.mock_resolver.resolve_service.side_effect = generate_resolve_service( self.mock_resolver.resolve_service.side_effect = generate_resolve_service(
[Server(host=b"srvtarget", port=8443)] [Server(host=b"srvtarget", port=8443)]
) )
@ -820,6 +1045,8 @@ class MatrixFederationAgentTests(unittest.TestCase):
"""Test the behaviour when the .well-known redirects to a place where there """Test the behaviour when the .well-known redirects to a place where there
is a SRV. is a SRV.
""" """
self.agent = self._make_agent()
self.reactor.lookups["testserv"] = "1.2.3.4" self.reactor.lookups["testserv"] = "1.2.3.4"
self.reactor.lookups["srvtarget"] = "5.6.7.8" self.reactor.lookups["srvtarget"] = "5.6.7.8"
@ -876,6 +1103,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
def test_idna_servername(self): def test_idna_servername(self):
"""test the behaviour when the server name has idna chars in""" """test the behaviour when the server name has idna chars in"""
self.agent = self._make_agent()
self.mock_resolver.resolve_service.side_effect = generate_resolve_service([]) self.mock_resolver.resolve_service.side_effect = generate_resolve_service([])
@ -937,6 +1165,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
def test_idna_srv_target(self): def test_idna_srv_target(self):
"""test the behaviour when the target of a SRV record has idna chars""" """test the behaviour when the target of a SRV record has idna chars"""
self.agent = self._make_agent()
self.mock_resolver.resolve_service.side_effect = generate_resolve_service( self.mock_resolver.resolve_service.side_effect = generate_resolve_service(
[Server(host=b"xn--trget-3qa.com", port=8443)] # târget.com [Server(host=b"xn--trget-3qa.com", port=8443)] # târget.com
@ -1140,6 +1369,8 @@ class MatrixFederationAgentTests(unittest.TestCase):
def test_srv_fallbacks(self): def test_srv_fallbacks(self):
"""Test that other SRV results are tried if the first one fails.""" """Test that other SRV results are tried if the first one fails."""
self.agent = self._make_agent()
self.mock_resolver.resolve_service.side_effect = generate_resolve_service( self.mock_resolver.resolve_service.side_effect = generate_resolve_service(
[ [
Server(host=b"target.com", port=8443), Server(host=b"target.com", port=8443),
@ -1266,34 +1497,49 @@ def _check_logcontext(context):
raise AssertionError("Expected logcontext %s but was %s" % (context, current)) raise AssertionError("Expected logcontext %s but was %s" % (context, current))
def _build_test_server(connection_creator): def _wrap_server_factory_for_tls(
"""Construct a test server factory: IProtocolFactory, sanlist: Iterable[bytes] = None
) -> IProtocolFactory:
This builds an HTTP channel, wrapped with a TLSMemoryBIOProtocol """Wrap an existing Protocol Factory with a test TLSMemoryBIOFactory
The resultant factory will create a TLS server which presents a certificate
signed by our test CA, valid for the domains in `sanlist`
Args: Args:
connection_creator (IOpenSSLServerConnectionCreator): thing to build factory: protocol factory to wrap
SSL connections sanlist: list of domains the cert should be valid for
sanlist (list[bytes]): list of the SAN entries for the cert returned
by the server
Returns: Returns:
TLSMemoryBIOProtocol interfaces.IProtocolFactory
"""
if sanlist is None:
sanlist = [
b"DNS:testserv",
b"DNS:target-server",
b"DNS:xn--bcher-kva.com",
b"IP:1.2.3.4",
b"IP:::1",
]
connection_creator = TestServerTLSConnectionFactory(sanlist=sanlist)
return TLSMemoryBIOFactory(
connection_creator, isClient=False, wrappedFactory=factory
)
def _get_test_protocol_factory() -> IProtocolFactory:
"""Get a protocol Factory which will build an HTTPChannel
Returns:
interfaces.IProtocolFactory
""" """
server_factory = Factory.forProtocol(HTTPChannel) server_factory = Factory.forProtocol(HTTPChannel)
# Request.finish expects the factory to have a 'log' method. # Request.finish expects the factory to have a 'log' method.
server_factory.log = _log_request server_factory.log = _log_request
server_tls_factory = TLSMemoryBIOFactory( return server_factory
connection_creator, isClient=False, wrappedFactory=server_factory
)
return server_tls_factory.buildProtocol(None)
def _log_request(request): def _log_request(request: str):
"""Implements Factory.log, which is expected by Request.finish""" """Implements Factory.log, which is expected by Request.finish"""
logger.info("Completed request %s", request) logger.info(f"Completed request {request}")
@implementer(IPolicyForHTTPS) @implementer(IPolicyForHTTPS)

View File

@ -29,7 +29,8 @@ from twisted.protocols.tls import TLSMemoryBIOFactory, TLSMemoryBIOProtocol
from twisted.web.http import HTTPChannel from twisted.web.http import HTTPChannel
from synapse.http.client import BlacklistingReactorWrapper from synapse.http.client import BlacklistingReactorWrapper
from synapse.http.proxyagent import ProxyAgent, ProxyCredentials, parse_proxy from synapse.http.connectproxyclient import ProxyCredentials
from synapse.http.proxyagent import ProxyAgent, parse_proxy
from tests.http import TestServerTLSConnectionFactory, get_test_https_policy from tests.http import TestServerTLSConnectionFactory, get_test_https_policy
from tests.server import FakeTransport, ThreadedMemoryReactorClock from tests.server import FakeTransport, ThreadedMemoryReactorClock
@ -392,7 +393,9 @@ class MatrixFederationAgentTests(TestCase):
""" """
Tests that requests can be made through a proxy. Tests that requests can be made through a proxy.
""" """
self._do_http_request_via_proxy(ssl=False, auth_credentials=None) self._do_http_request_via_proxy(
expect_proxy_ssl=False, expected_auth_credentials=None
)
@patch.dict( @patch.dict(
os.environ, os.environ,
@ -402,13 +405,17 @@ class MatrixFederationAgentTests(TestCase):
""" """
Tests that authenticated requests can be made through a proxy. Tests that authenticated requests can be made through a proxy.
""" """
self._do_http_request_via_proxy(ssl=False, auth_credentials=b"bob:pinkponies") self._do_http_request_via_proxy(
expect_proxy_ssl=False, expected_auth_credentials=b"bob:pinkponies"
)
@patch.dict( @patch.dict(
os.environ, {"http_proxy": "https://proxy.com:8888", "no_proxy": "unused.com"} os.environ, {"http_proxy": "https://proxy.com:8888", "no_proxy": "unused.com"}
) )
def test_http_request_via_https_proxy(self): def test_http_request_via_https_proxy(self):
self._do_http_request_via_proxy(ssl=True, auth_credentials=None) self._do_http_request_via_proxy(
expect_proxy_ssl=True, expected_auth_credentials=None
)
@patch.dict( @patch.dict(
os.environ, os.environ,
@ -418,12 +425,16 @@ class MatrixFederationAgentTests(TestCase):
}, },
) )
def test_http_request_via_https_proxy_with_auth(self): def test_http_request_via_https_proxy_with_auth(self):
self._do_http_request_via_proxy(ssl=True, auth_credentials=b"bob:pinkponies") self._do_http_request_via_proxy(
expect_proxy_ssl=True, expected_auth_credentials=b"bob:pinkponies"
)
@patch.dict(os.environ, {"https_proxy": "proxy.com", "no_proxy": "unused.com"}) @patch.dict(os.environ, {"https_proxy": "proxy.com", "no_proxy": "unused.com"})
def test_https_request_via_proxy(self): def test_https_request_via_proxy(self):
"""Tests that TLS-encrypted requests can be made through a proxy""" """Tests that TLS-encrypted requests can be made through a proxy"""
self._do_https_request_via_proxy(ssl=False, auth_credentials=None) self._do_https_request_via_proxy(
expect_proxy_ssl=False, expected_auth_credentials=None
)
@patch.dict( @patch.dict(
os.environ, os.environ,
@ -431,14 +442,18 @@ class MatrixFederationAgentTests(TestCase):
) )
def test_https_request_via_proxy_with_auth(self): def test_https_request_via_proxy_with_auth(self):
"""Tests that authenticated, TLS-encrypted requests can be made through a proxy""" """Tests that authenticated, TLS-encrypted requests can be made through a proxy"""
self._do_https_request_via_proxy(ssl=False, auth_credentials=b"bob:pinkponies") self._do_https_request_via_proxy(
expect_proxy_ssl=False, expected_auth_credentials=b"bob:pinkponies"
)
@patch.dict( @patch.dict(
os.environ, {"https_proxy": "https://proxy.com", "no_proxy": "unused.com"} os.environ, {"https_proxy": "https://proxy.com", "no_proxy": "unused.com"}
) )
def test_https_request_via_https_proxy(self): def test_https_request_via_https_proxy(self):
"""Tests that TLS-encrypted requests can be made through a proxy""" """Tests that TLS-encrypted requests can be made through a proxy"""
self._do_https_request_via_proxy(ssl=True, auth_credentials=None) self._do_https_request_via_proxy(
expect_proxy_ssl=True, expected_auth_credentials=None
)
@patch.dict( @patch.dict(
os.environ, os.environ,
@ -446,20 +461,22 @@ class MatrixFederationAgentTests(TestCase):
) )
def test_https_request_via_https_proxy_with_auth(self): def test_https_request_via_https_proxy_with_auth(self):
"""Tests that authenticated, TLS-encrypted requests can be made through a proxy""" """Tests that authenticated, TLS-encrypted requests can be made through a proxy"""
self._do_https_request_via_proxy(ssl=True, auth_credentials=b"bob:pinkponies") self._do_https_request_via_proxy(
expect_proxy_ssl=True, expected_auth_credentials=b"bob:pinkponies"
)
def _do_http_request_via_proxy( def _do_http_request_via_proxy(
self, self,
ssl: bool = False, expect_proxy_ssl: bool = False,
auth_credentials: Optional[bytes] = None, expected_auth_credentials: Optional[bytes] = None,
): ):
"""Send a http request via an agent and check that it is correctly received at """Send a http request via an agent and check that it is correctly received at
the proxy. The proxy can use either http or https. the proxy. The proxy can use either http or https.
Args: Args:
ssl: True if we expect the request to connect via https to proxy expect_proxy_ssl: True if we expect the request to connect via https to proxy
auth_credentials: credentials to authenticate at proxy expected_auth_credentials: credentials to authenticate at proxy
""" """
if ssl: if expect_proxy_ssl:
agent = ProxyAgent( agent = ProxyAgent(
self.reactor, use_proxy=True, contextFactory=get_test_https_policy() self.reactor, use_proxy=True, contextFactory=get_test_https_policy()
) )
@ -480,9 +497,9 @@ class MatrixFederationAgentTests(TestCase):
http_server = self._make_connection( http_server = self._make_connection(
client_factory, client_factory,
_get_test_protocol_factory(), _get_test_protocol_factory(),
ssl=ssl, ssl=expect_proxy_ssl,
tls_sanlist=[b"DNS:proxy.com"] if ssl else None, tls_sanlist=[b"DNS:proxy.com"] if expect_proxy_ssl else None,
expected_sni=b"proxy.com" if ssl else None, expected_sni=b"proxy.com" if expect_proxy_ssl else None,
) )
# the FakeTransport is async, so we need to pump the reactor # the FakeTransport is async, so we need to pump the reactor
@ -498,9 +515,9 @@ class MatrixFederationAgentTests(TestCase):
b"Proxy-Authorization" b"Proxy-Authorization"
) )
if auth_credentials is not None: if expected_auth_credentials is not None:
# Compute the correct header value for Proxy-Authorization # Compute the correct header value for Proxy-Authorization
encoded_credentials = base64.b64encode(auth_credentials) encoded_credentials = base64.b64encode(expected_auth_credentials)
expected_header_value = b"Basic " + encoded_credentials expected_header_value = b"Basic " + encoded_credentials
# Validate the header's value # Validate the header's value
@ -523,14 +540,14 @@ class MatrixFederationAgentTests(TestCase):
def _do_https_request_via_proxy( def _do_https_request_via_proxy(
self, self,
ssl: bool = False, expect_proxy_ssl: bool = False,
auth_credentials: Optional[bytes] = None, expected_auth_credentials: Optional[bytes] = None,
): ):
"""Send a https request via an agent and check that it is correctly received at """Send a https request via an agent and check that it is correctly received at
the proxy and client. The proxy can use either http or https. the proxy and client. The proxy can use either http or https.
Args: Args:
ssl: True if we expect the request to connect via https to proxy expect_proxy_ssl: True if we expect the request to connect via https to proxy
auth_credentials: credentials to authenticate at proxy expected_auth_credentials: credentials to authenticate at proxy
""" """
agent = ProxyAgent( agent = ProxyAgent(
self.reactor, self.reactor,
@ -552,9 +569,9 @@ class MatrixFederationAgentTests(TestCase):
proxy_server = self._make_connection( proxy_server = self._make_connection(
client_factory, client_factory,
_get_test_protocol_factory(), _get_test_protocol_factory(),
ssl=ssl, ssl=expect_proxy_ssl,
tls_sanlist=[b"DNS:proxy.com"] if ssl else None, tls_sanlist=[b"DNS:proxy.com"] if expect_proxy_ssl else None,
expected_sni=b"proxy.com" if ssl else None, expected_sni=b"proxy.com" if expect_proxy_ssl else None,
) )
assert isinstance(proxy_server, HTTPChannel) assert isinstance(proxy_server, HTTPChannel)
@ -570,9 +587,9 @@ class MatrixFederationAgentTests(TestCase):
b"Proxy-Authorization" b"Proxy-Authorization"
) )
if auth_credentials is not None: if expected_auth_credentials is not None:
# Compute the correct header value for Proxy-Authorization # Compute the correct header value for Proxy-Authorization
encoded_credentials = base64.b64encode(auth_credentials) encoded_credentials = base64.b64encode(expected_auth_credentials)
expected_header_value = b"Basic " + encoded_credentials expected_header_value = b"Basic " + encoded_credentials
# Validate the header's value # Validate the header's value
@ -606,7 +623,7 @@ class MatrixFederationAgentTests(TestCase):
# Protocol to implement the proxy, which starts out by forwarding to an # Protocol to implement the proxy, which starts out by forwarding to an
# HTTPChannel (to implement the CONNECT command) and can then be switched # HTTPChannel (to implement the CONNECT command) and can then be switched
# into a mode where it forwards its traffic to another Protocol.) # into a mode where it forwards its traffic to another Protocol.)
if ssl: if expect_proxy_ssl:
assert isinstance(proxy_server_transport, TLSMemoryBIOProtocol) assert isinstance(proxy_server_transport, TLSMemoryBIOProtocol)
proxy_server_transport.wrappedProtocol = server_ssl_protocol proxy_server_transport.wrappedProtocol = server_ssl_protocol
else: else: