Merge pull request #5864 from matrix-org/erikj/reliable_lookups
Refactor MatrixFederationAgent to retry SRV. (pull/5919/head)
commit
dfd10f5133
|
@ -0,0 +1 @@
|
||||||
|
Correctly retry all hosts returned from SRV when we fail to connect.
|
|
@ -14,21 +14,21 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
import urllib
|
||||||
|
|
||||||
import attr
|
from netaddr import AddrFormatError, IPAddress
|
||||||
from netaddr import IPAddress
|
|
||||||
from zope.interface import implementer
|
from zope.interface import implementer
|
||||||
|
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
from twisted.internet.endpoints import HostnameEndpoint, wrapClientTLS
|
from twisted.internet.endpoints import HostnameEndpoint, wrapClientTLS
|
||||||
from twisted.internet.interfaces import IStreamClientEndpoint
|
from twisted.internet.interfaces import IStreamClientEndpoint
|
||||||
from twisted.web.client import URI, Agent, HTTPConnectionPool
|
from twisted.web.client import Agent, HTTPConnectionPool
|
||||||
from twisted.web.http_headers import Headers
|
from twisted.web.http_headers import Headers
|
||||||
from twisted.web.iweb import IAgent
|
from twisted.web.iweb import IAgent, IAgentEndpointFactory
|
||||||
|
|
||||||
from synapse.http.federation.srv_resolver import SrvResolver, pick_server_from_list
|
from synapse.http.federation.srv_resolver import Server, SrvResolver
|
||||||
from synapse.http.federation.well_known_resolver import WellKnownResolver
|
from synapse.http.federation.well_known_resolver import WellKnownResolver
|
||||||
from synapse.logging.context import make_deferred_yieldable
|
from synapse.logging.context import make_deferred_yieldable, run_in_background
|
||||||
from synapse.util import Clock
|
from synapse.util import Clock
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
@ -36,8 +36,9 @@ logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@implementer(IAgent)
|
@implementer(IAgent)
|
||||||
class MatrixFederationAgent(object):
|
class MatrixFederationAgent(object):
|
||||||
"""An Agent-like thing which provides a `request` method which will look up a matrix
|
"""An Agent-like thing which provides a `request` method which correctly
|
||||||
server and send an HTTP request to it.
|
handles resolving matrix server names when using matrix://. Handles standard
|
||||||
|
https URIs as normal.
|
||||||
|
|
||||||
Doesn't implement any retries. (Those are done in MatrixFederationHttpClient.)
|
Doesn't implement any retries. (Those are done in MatrixFederationHttpClient.)
|
||||||
|
|
||||||
|
@ -65,17 +66,19 @@ class MatrixFederationAgent(object):
|
||||||
):
|
):
|
||||||
self._reactor = reactor
|
self._reactor = reactor
|
||||||
self._clock = Clock(reactor)
|
self._clock = Clock(reactor)
|
||||||
|
|
||||||
self._tls_client_options_factory = tls_client_options_factory
|
|
||||||
if _srv_resolver is None:
|
|
||||||
_srv_resolver = SrvResolver()
|
|
||||||
self._srv_resolver = _srv_resolver
|
|
||||||
|
|
||||||
self._pool = HTTPConnectionPool(reactor)
|
self._pool = HTTPConnectionPool(reactor)
|
||||||
self._pool.retryAutomatically = False
|
self._pool.retryAutomatically = False
|
||||||
self._pool.maxPersistentPerHost = 5
|
self._pool.maxPersistentPerHost = 5
|
||||||
self._pool.cachedConnectionTimeout = 2 * 60
|
self._pool.cachedConnectionTimeout = 2 * 60
|
||||||
|
|
||||||
|
self._agent = Agent.usingEndpointFactory(
|
||||||
|
self._reactor,
|
||||||
|
MatrixHostnameEndpointFactory(
|
||||||
|
reactor, tls_client_options_factory, _srv_resolver
|
||||||
|
),
|
||||||
|
pool=self._pool,
|
||||||
|
)
|
||||||
|
|
||||||
if _well_known_resolver is None:
|
if _well_known_resolver is None:
|
||||||
_well_known_resolver = WellKnownResolver(
|
_well_known_resolver = WellKnownResolver(
|
||||||
self._reactor,
|
self._reactor,
|
||||||
|
@ -93,19 +96,15 @@ class MatrixFederationAgent(object):
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
method (bytes): HTTP method: GET/POST/etc
|
method (bytes): HTTP method: GET/POST/etc
|
||||||
|
|
||||||
uri (bytes): Absolute URI to be retrieved
|
uri (bytes): Absolute URI to be retrieved
|
||||||
|
|
||||||
headers (twisted.web.http_headers.Headers|None):
|
headers (twisted.web.http_headers.Headers|None):
|
||||||
HTTP headers to send with the request, or None to
|
HTTP headers to send with the request, or None to
|
||||||
send no extra headers.
|
send no extra headers.
|
||||||
|
|
||||||
bodyProducer (twisted.web.iweb.IBodyProducer|None):
|
bodyProducer (twisted.web.iweb.IBodyProducer|None):
|
||||||
An object which can generate bytes to make up the
|
An object which can generate bytes to make up the
|
||||||
body of this request (for example, the properly encoded contents of
|
body of this request (for example, the properly encoded contents of
|
||||||
a file for a file upload). Or None if the request is to have
|
a file for a file upload). Or None if the request is to have
|
||||||
no body.
|
no body.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Deferred[twisted.web.iweb.IResponse]:
|
Deferred[twisted.web.iweb.IResponse]:
|
||||||
fires when the header of the response has been received (regardless of the
|
fires when the header of the response has been received (regardless of the
|
||||||
|
@ -113,210 +112,207 @@ class MatrixFederationAgent(object):
|
||||||
response from being received (including problems that prevent the request
|
response from being received (including problems that prevent the request
|
||||||
from being sent).
|
from being sent).
|
||||||
"""
|
"""
|
||||||
parsed_uri = URI.fromBytes(uri, defaultPort=-1)
|
# We use urlparse as that will set `port` to None if there is no
|
||||||
res = yield self._route_matrix_uri(parsed_uri)
|
# explicit port.
|
||||||
|
parsed_uri = urllib.parse.urlparse(uri)
|
||||||
|
|
||||||
# set up the TLS connection params
|
# If this is a matrix:// URI check if the server has delegated matrix
|
||||||
|
# traffic using well-known delegation.
|
||||||
#
|
#
|
||||||
# XXX disabling TLS is really only supported here for the benefit of the
|
# We have to do this here and not in the endpoint as we need to rewrite
|
||||||
# unit tests. We should make the UTs cope with TLS rather than having to make
|
# the host header with the delegated server name.
|
||||||
# the code support the unit tests.
|
delegated_server = None
|
||||||
if self._tls_client_options_factory is None:
|
if (
|
||||||
tls_options = None
|
parsed_uri.scheme == b"matrix"
|
||||||
else:
|
and not _is_ip_literal(parsed_uri.hostname)
|
||||||
tls_options = self._tls_client_options_factory.get_options(
|
and not parsed_uri.port
|
||||||
res.tls_server_name.decode("ascii")
|
):
|
||||||
|
well_known_result = yield self._well_known_resolver.get_well_known(
|
||||||
|
parsed_uri.hostname
|
||||||
)
|
)
|
||||||
|
delegated_server = well_known_result.delegated_server
|
||||||
|
|
||||||
# make sure that the Host header is set correctly
|
if delegated_server:
|
||||||
|
# Ok, the server has delegated matrix traffic to somewhere else, so
|
||||||
|
# let's rewrite the URL to replace the server with the delegated
|
||||||
|
# server name.
|
||||||
|
uri = urllib.parse.urlunparse(
|
||||||
|
(
|
||||||
|
parsed_uri.scheme,
|
||||||
|
delegated_server,
|
||||||
|
parsed_uri.path,
|
||||||
|
parsed_uri.params,
|
||||||
|
parsed_uri.query,
|
||||||
|
parsed_uri.fragment,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
parsed_uri = urllib.parse.urlparse(uri)
|
||||||
|
|
||||||
|
# We need to make sure the host header is set to the netloc of the
|
||||||
|
# server.
|
||||||
if headers is None:
|
if headers is None:
|
||||||
headers = Headers()
|
headers = Headers()
|
||||||
else:
|
else:
|
||||||
headers = headers.copy()
|
headers = headers.copy()
|
||||||
|
|
||||||
if not headers.hasHeader(b"host"):
|
if not headers.hasHeader(b"host"):
|
||||||
headers.addRawHeader(b"host", res.host_header)
|
headers.addRawHeader(b"host", parsed_uri.netloc)
|
||||||
|
|
||||||
class EndpointFactory(object):
|
|
||||||
@staticmethod
|
|
||||||
def endpointForURI(_uri):
|
|
||||||
ep = LoggingHostnameEndpoint(
|
|
||||||
self._reactor, res.target_host, res.target_port
|
|
||||||
)
|
|
||||||
if tls_options is not None:
|
|
||||||
ep = wrapClientTLS(tls_options, ep)
|
|
||||||
return ep
|
|
||||||
|
|
||||||
agent = Agent.usingEndpointFactory(self._reactor, EndpointFactory(), self._pool)
|
|
||||||
res = yield make_deferred_yieldable(
|
res = yield make_deferred_yieldable(
|
||||||
agent.request(method, uri, headers, bodyProducer)
|
self._agent.request(method, uri, headers, bodyProducer)
|
||||||
)
|
)
|
||||||
|
|
||||||
return res
|
return res
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
|
||||||
def _route_matrix_uri(self, parsed_uri, lookup_well_known=True):
|
|
||||||
"""Helper for `request`: determine the routing for a Matrix URI
|
|
||||||
|
|
||||||
Args:
|
@implementer(IAgentEndpointFactory)
|
||||||
parsed_uri (twisted.web.client.URI): uri to route. Note that it should be
|
class MatrixHostnameEndpointFactory(object):
|
||||||
parsed with URI.fromBytes(uri, defaultPort=-1) to set the `port` to -1
|
"""Factory for MatrixHostnameEndpoint for passing to an Agent.
|
||||||
if there is no explicit port given.
|
|
||||||
|
|
||||||
lookup_well_known (bool): True if we should look up the .well-known file if
|
|
||||||
there is no SRV record.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Deferred[_RoutingResult]
|
|
||||||
"""
|
"""
|
||||||
# check for an IP literal
|
|
||||||
try:
|
|
||||||
ip_address = IPAddress(parsed_uri.host.decode("ascii"))
|
|
||||||
except Exception:
|
|
||||||
# not an IP address
|
|
||||||
ip_address = None
|
|
||||||
|
|
||||||
if ip_address:
|
def __init__(self, reactor, tls_client_options_factory, srv_resolver):
|
||||||
port = parsed_uri.port
|
self._reactor = reactor
|
||||||
if port == -1:
|
self._tls_client_options_factory = tls_client_options_factory
|
||||||
port = 8448
|
|
||||||
return _RoutingResult(
|
|
||||||
host_header=parsed_uri.netloc,
|
|
||||||
tls_server_name=parsed_uri.host,
|
|
||||||
target_host=parsed_uri.host,
|
|
||||||
target_port=port,
|
|
||||||
)
|
|
||||||
|
|
||||||
if parsed_uri.port != -1:
|
if srv_resolver is None:
|
||||||
# there is an explicit port
|
srv_resolver = SrvResolver()
|
||||||
return _RoutingResult(
|
|
||||||
host_header=parsed_uri.netloc,
|
|
||||||
tls_server_name=parsed_uri.host,
|
|
||||||
target_host=parsed_uri.host,
|
|
||||||
target_port=parsed_uri.port,
|
|
||||||
)
|
|
||||||
|
|
||||||
if lookup_well_known:
|
self._srv_resolver = srv_resolver
|
||||||
# try a .well-known lookup
|
|
||||||
well_known_result = yield self._well_known_resolver.get_well_known(
|
|
||||||
parsed_uri.host
|
|
||||||
)
|
|
||||||
well_known_server = well_known_result.delegated_server
|
|
||||||
|
|
||||||
if well_known_server:
|
def endpointForURI(self, parsed_uri):
|
||||||
# if we found a .well-known, start again, but don't do another
|
return MatrixHostnameEndpoint(
|
||||||
# .well-known lookup.
|
self._reactor,
|
||||||
|
self._tls_client_options_factory,
|
||||||
# parse the server name in the .well-known response into host/port.
|
self._srv_resolver,
|
||||||
# (This code is lifted from twisted.web.client.URI.fromBytes).
|
parsed_uri,
|
||||||
if b":" in well_known_server:
|
|
||||||
well_known_host, well_known_port = well_known_server.rsplit(b":", 1)
|
|
||||||
try:
|
|
||||||
well_known_port = int(well_known_port)
|
|
||||||
except ValueError:
|
|
||||||
# the part after the colon could not be parsed as an int
|
|
||||||
# - we assume it is an IPv6 literal with no port (the closing
|
|
||||||
# ']' stops it being parsed as an int)
|
|
||||||
well_known_host, well_known_port = well_known_server, -1
|
|
||||||
else:
|
|
||||||
well_known_host, well_known_port = well_known_server, -1
|
|
||||||
|
|
||||||
new_uri = URI(
|
|
||||||
scheme=parsed_uri.scheme,
|
|
||||||
netloc=well_known_server,
|
|
||||||
host=well_known_host,
|
|
||||||
port=well_known_port,
|
|
||||||
path=parsed_uri.path,
|
|
||||||
params=parsed_uri.params,
|
|
||||||
query=parsed_uri.query,
|
|
||||||
fragment=parsed_uri.fragment,
|
|
||||||
)
|
|
||||||
|
|
||||||
res = yield self._route_matrix_uri(new_uri, lookup_well_known=False)
|
|
||||||
return res
|
|
||||||
|
|
||||||
# try a SRV lookup
|
|
||||||
service_name = b"_matrix._tcp.%s" % (parsed_uri.host,)
|
|
||||||
server_list = yield self._srv_resolver.resolve_service(service_name)
|
|
||||||
|
|
||||||
if not server_list:
|
|
||||||
target_host = parsed_uri.host
|
|
||||||
port = 8448
|
|
||||||
logger.debug(
|
|
||||||
"No SRV record for %s, using %s:%i",
|
|
||||||
parsed_uri.host.decode("ascii"),
|
|
||||||
target_host.decode("ascii"),
|
|
||||||
port,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
target_host, port = pick_server_from_list(server_list)
|
|
||||||
logger.debug(
|
|
||||||
"Picked %s:%i from SRV records for %s",
|
|
||||||
target_host.decode("ascii"),
|
|
||||||
port,
|
|
||||||
parsed_uri.host.decode("ascii"),
|
|
||||||
)
|
|
||||||
|
|
||||||
return _RoutingResult(
|
|
||||||
host_header=parsed_uri.netloc,
|
|
||||||
tls_server_name=parsed_uri.host,
|
|
||||||
target_host=target_host,
|
|
||||||
target_port=port,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@implementer(IStreamClientEndpoint)
|
@implementer(IStreamClientEndpoint)
|
||||||
class LoggingHostnameEndpoint(object):
|
class MatrixHostnameEndpoint(object):
|
||||||
"""A wrapper for HostnameEndpoint which logs when it connects"""
|
"""An endpoint that resolves matrix:// URLs using Matrix server name
|
||||||
|
resolution (i.e. via SRV). Does not check for well-known delegation.
|
||||||
|
|
||||||
def __init__(self, reactor, host, port, *args, **kwargs):
|
Args:
|
||||||
self.host = host
|
reactor (IReactor)
|
||||||
self.port = port
|
tls_client_options_factory (ClientTLSOptionsFactory|None):
|
||||||
self.ep = HostnameEndpoint(reactor, host, port, *args, **kwargs)
|
factory to use for fetching client tls options, or none to disable TLS.
|
||||||
|
srv_resolver (SrvResolver): The SRV resolver to use
|
||||||
|
parsed_uri (twisted.web.client.URI): The parsed URI that we're wanting
|
||||||
|
to connect to.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, reactor, tls_client_options_factory, srv_resolver, parsed_uri):
|
||||||
|
self._reactor = reactor
|
||||||
|
|
||||||
|
self._parsed_uri = parsed_uri
|
||||||
|
|
||||||
|
# set up the TLS connection params
|
||||||
|
#
|
||||||
|
# XXX disabling TLS is really only supported here for the benefit of the
|
||||||
|
# unit tests. We should make the UTs cope with TLS rather than having to make
|
||||||
|
# the code support the unit tests.
|
||||||
|
|
||||||
|
if tls_client_options_factory is None:
|
||||||
|
self._tls_options = None
|
||||||
|
else:
|
||||||
|
self._tls_options = tls_client_options_factory.get_options(
|
||||||
|
self._parsed_uri.host.decode("ascii")
|
||||||
|
)
|
||||||
|
|
||||||
|
self._srv_resolver = srv_resolver
|
||||||
|
|
||||||
def connect(self, protocol_factory):
|
def connect(self, protocol_factory):
|
||||||
logger.info("Connecting to %s:%i", self.host.decode("ascii"), self.port)
|
"""Implements IStreamClientEndpoint interface
|
||||||
return self.ep.connect(protocol_factory)
|
|
||||||
|
|
||||||
|
|
||||||
@attr.s
|
|
||||||
class _RoutingResult(object):
|
|
||||||
"""The result returned by `_route_matrix_uri`.
|
|
||||||
|
|
||||||
Contains the parameters needed to direct a federation connection to a particular
|
|
||||||
server.
|
|
||||||
|
|
||||||
Where a SRV record points to several servers, this object contains a single server
|
|
||||||
chosen from the list.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
host_header = attr.ib()
|
return run_in_background(self._do_connect, protocol_factory)
|
||||||
"""
|
|
||||||
The value we should assign to the Host header (host:port from the matrix
|
|
||||||
URI, or .well-known).
|
|
||||||
|
|
||||||
:type: bytes
|
@defer.inlineCallbacks
|
||||||
|
def _do_connect(self, protocol_factory):
|
||||||
|
first_exception = None
|
||||||
|
|
||||||
|
server_list = yield self._resolve_server()
|
||||||
|
|
||||||
|
for server in server_list:
|
||||||
|
host = server.host
|
||||||
|
port = server.port
|
||||||
|
|
||||||
|
try:
|
||||||
|
logger.info("Connecting to %s:%i", host.decode("ascii"), port)
|
||||||
|
endpoint = HostnameEndpoint(self._reactor, host, port)
|
||||||
|
if self._tls_options:
|
||||||
|
endpoint = wrapClientTLS(self._tls_options, endpoint)
|
||||||
|
result = yield make_deferred_yieldable(
|
||||||
|
endpoint.connect(protocol_factory)
|
||||||
|
)
|
||||||
|
|
||||||
|
return result
|
||||||
|
except Exception as e:
|
||||||
|
logger.info(
|
||||||
|
"Failed to connect to %s:%i: %s", host.decode("ascii"), port, e
|
||||||
|
)
|
||||||
|
if not first_exception:
|
||||||
|
first_exception = e
|
||||||
|
|
||||||
|
# We return the first failure because that's probably the most interesting.
|
||||||
|
if first_exception:
|
||||||
|
raise first_exception
|
||||||
|
|
||||||
|
# This shouldn't happen as we should always have at least one host/port
|
||||||
|
# to try and if that doesn't work then we'll have an exception.
|
||||||
|
raise Exception("Failed to resolve server %r" % (self._parsed_uri.netloc,))
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def _resolve_server(self):
|
||||||
|
"""Resolves the server name to a list of hosts and ports to attempt to
|
||||||
|
connect to.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Deferred[list[Server]]
|
||||||
"""
|
"""
|
||||||
|
|
||||||
tls_server_name = attr.ib()
|
if self._parsed_uri.scheme != b"matrix":
|
||||||
"""
|
return [Server(host=self._parsed_uri.host, port=self._parsed_uri.port)]
|
||||||
The server name we should set in the SNI (typically host, without port, from the
|
|
||||||
matrix URI or .well-known)
|
|
||||||
|
|
||||||
:type: bytes
|
# Note: We don't do well-known lookup as that needs to have happened
|
||||||
|
# before now, due to needing to rewrite the Host header of the HTTP
|
||||||
|
# request.
|
||||||
|
|
||||||
|
# We reparse the URI with urlparse so that `port` is None (rather than the default 80) when no explicit port was given
|
||||||
|
parsed_uri = urllib.parse.urlparse(self._parsed_uri.toBytes())
|
||||||
|
|
||||||
|
host = parsed_uri.hostname
|
||||||
|
port = parsed_uri.port
|
||||||
|
|
||||||
|
# If there is an explicit port or the host is an IP address we bypass
|
||||||
|
# SRV lookups and just use the given host/port.
|
||||||
|
if port or _is_ip_literal(host):
|
||||||
|
return [Server(host, port or 8448)]
|
||||||
|
|
||||||
|
server_list = yield self._srv_resolver.resolve_service(b"_matrix._tcp." + host)
|
||||||
|
|
||||||
|
if server_list:
|
||||||
|
return server_list
|
||||||
|
|
||||||
|
# No SRV records, so we fallback to host and 8448
|
||||||
|
return [Server(host, 8448)]
|
||||||
|
|
||||||
|
|
||||||
|
def _is_ip_literal(host):
|
||||||
|
"""Test if the given host name is either an IPv4 or IPv6 literal.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
host (bytes)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool
|
||||||
"""
|
"""
|
||||||
|
|
||||||
target_host = attr.ib()
|
host = host.decode("ascii")
|
||||||
"""
|
|
||||||
The hostname (or IP literal) we should route the TCP connection to (the target of the
|
|
||||||
SRV record, or the hostname from the URL/.well-known)
|
|
||||||
|
|
||||||
:type: bytes
|
try:
|
||||||
"""
|
IPAddress(host)
|
||||||
|
return True
|
||||||
target_port = attr.ib()
|
except AddrFormatError:
|
||||||
"""
|
return False
|
||||||
The port we should route the TCP connection to (the target of the SRV record, or
|
|
||||||
the port from the URL/.well-known, or 8448)
|
|
||||||
|
|
||||||
:type: int
|
|
||||||
"""
|
|
||||||
|
|
|
@ -32,7 +32,7 @@ logger = logging.getLogger(__name__)
|
||||||
SERVER_CACHE = {}
|
SERVER_CACHE = {}
|
||||||
|
|
||||||
|
|
||||||
@attr.s
|
@attr.s(slots=True, frozen=True)
|
||||||
class Server(object):
|
class Server(object):
|
||||||
"""
|
"""
|
||||||
Our record of an individual server which can be tried to reach a destination.
|
Our record of an individual server which can be tried to reach a destination.
|
||||||
|
@ -53,34 +53,47 @@ class Server(object):
|
||||||
expires = attr.ib(default=0)
|
expires = attr.ib(default=0)
|
||||||
|
|
||||||
|
|
||||||
def pick_server_from_list(server_list):
|
def _sort_server_list(server_list):
|
||||||
"""Randomly choose a server from the server list
|
"""Given a list of SRV records sort them into priority order and shuffle
|
||||||
|
each priority with the given weight.
|
||||||
Args:
|
|
||||||
server_list (list[Server]): list of candidate servers
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Tuple[bytes, int]: (host, port) pair for the chosen server
|
|
||||||
"""
|
"""
|
||||||
if not server_list:
|
priority_map = {}
|
||||||
raise RuntimeError("pick_server_from_list called with empty list")
|
|
||||||
|
|
||||||
# TODO: currently we only use the lowest-priority servers. We should maintain a
|
for server in server_list:
|
||||||
# cache of servers known to be "down" and filter them out
|
priority_map.setdefault(server.priority, []).append(server)
|
||||||
|
|
||||||
min_priority = min(s.priority for s in server_list)
|
results = []
|
||||||
eligible_servers = list(s for s in server_list if s.priority == min_priority)
|
for priority in sorted(priority_map):
|
||||||
total_weight = sum(s.weight for s in eligible_servers)
|
servers = priority_map[priority]
|
||||||
target_weight = random.randint(0, total_weight)
|
|
||||||
|
|
||||||
for s in eligible_servers:
|
# This algorithm roughly follows the algorithm described in RFC2782,
|
||||||
|
# changed to remove an off-by-one error.
|
||||||
|
#
|
||||||
|
# N.B. Weights can be zero, which means that they should be picked
|
||||||
|
# rarely.
|
||||||
|
|
||||||
|
total_weight = sum(s.weight for s in servers)
|
||||||
|
|
||||||
|
# Total weight can become zero if there are only zero weight servers
|
||||||
|
# left, which we handle by just shuffling and appending to the results.
|
||||||
|
while servers and total_weight:
|
||||||
|
target_weight = random.randint(1, total_weight)
|
||||||
|
|
||||||
|
for s in servers:
|
||||||
target_weight -= s.weight
|
target_weight -= s.weight
|
||||||
|
|
||||||
if target_weight <= 0:
|
if target_weight <= 0:
|
||||||
return s.host, s.port
|
break
|
||||||
|
|
||||||
# this should be impossible.
|
results.append(s)
|
||||||
raise RuntimeError("pick_server_from_list got to end of eligible server list.")
|
servers.remove(s)
|
||||||
|
total_weight -= s.weight
|
||||||
|
|
||||||
|
if servers:
|
||||||
|
random.shuffle(servers)
|
||||||
|
results.extend(servers)
|
||||||
|
|
||||||
|
return results
|
||||||
|
|
||||||
|
|
||||||
class SrvResolver(object):
|
class SrvResolver(object):
|
||||||
|
@ -120,7 +133,7 @@ class SrvResolver(object):
|
||||||
if cache_entry:
|
if cache_entry:
|
||||||
if all(s.expires > now for s in cache_entry):
|
if all(s.expires > now for s in cache_entry):
|
||||||
servers = list(cache_entry)
|
servers = list(cache_entry)
|
||||||
return servers
|
return _sort_server_list(servers)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
answers, _, _ = yield make_deferred_yieldable(
|
answers, _, _ = yield make_deferred_yieldable(
|
||||||
|
@ -169,4 +182,4 @@ class SrvResolver(object):
|
||||||
)
|
)
|
||||||
|
|
||||||
self._cache[service_name] = list(servers)
|
self._cache[service_name] = list(servers)
|
||||||
return servers
|
return _sort_server_list(servers)
|
||||||
|
|
|
@ -20,7 +20,6 @@ from synapse.federation.federation_server import server_matches_acl_event
|
||||||
from tests import unittest
|
from tests import unittest
|
||||||
|
|
||||||
|
|
||||||
@unittest.DEBUG
|
|
||||||
class ServerACLsTestCase(unittest.TestCase):
|
class ServerACLsTestCase(unittest.TestCase):
|
||||||
def test_blacklisted_server(self):
|
def test_blacklisted_server(self):
|
||||||
e = _create_acl_event({"allow": ["*"], "deny": ["evil.com"]})
|
e = _create_acl_event({"allow": ["*"], "deny": ["evil.com"]})
|
||||||
|
|
|
@ -41,9 +41,9 @@ from synapse.http.federation.well_known_resolver import (
|
||||||
from synapse.logging.context import LoggingContext
|
from synapse.logging.context import LoggingContext
|
||||||
from synapse.util.caches.ttlcache import TTLCache
|
from synapse.util.caches.ttlcache import TTLCache
|
||||||
|
|
||||||
|
from tests import unittest
|
||||||
from tests.http import TestServerTLSConnectionFactory, get_test_ca_cert_file
|
from tests.http import TestServerTLSConnectionFactory, get_test_ca_cert_file
|
||||||
from tests.server import FakeTransport, ThreadedMemoryReactorClock
|
from tests.server import FakeTransport, ThreadedMemoryReactorClock
|
||||||
from tests.unittest import TestCase
|
|
||||||
from tests.utils import default_config
|
from tests.utils import default_config
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
@ -67,7 +67,7 @@ def get_connection_factory():
|
||||||
return test_server_connection_factory
|
return test_server_connection_factory
|
||||||
|
|
||||||
|
|
||||||
class MatrixFederationAgentTests(TestCase):
|
class MatrixFederationAgentTests(unittest.TestCase):
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
self.reactor = ThreadedMemoryReactorClock()
|
self.reactor = ThreadedMemoryReactorClock()
|
||||||
|
|
||||||
|
@ -1069,8 +1069,64 @@ class MatrixFederationAgentTests(TestCase):
|
||||||
r = self.successResultOf(fetch_d)
|
r = self.successResultOf(fetch_d)
|
||||||
self.assertEqual(r.delegated_server, None)
|
self.assertEqual(r.delegated_server, None)
|
||||||
|
|
||||||
|
def test_srv_fallbacks(self):
|
||||||
|
"""Test that other SRV results are tried if the first one fails.
|
||||||
|
"""
|
||||||
|
|
||||||
class TestCachePeriodFromHeaders(TestCase):
|
self.mock_resolver.resolve_service.side_effect = lambda _: [
|
||||||
|
Server(host=b"target.com", port=8443),
|
||||||
|
Server(host=b"target.com", port=8444),
|
||||||
|
]
|
||||||
|
self.reactor.lookups["target.com"] = "1.2.3.4"
|
||||||
|
|
||||||
|
test_d = self._make_get_request(b"matrix://testserv/foo/bar")
|
||||||
|
|
||||||
|
# Nothing happened yet
|
||||||
|
self.assertNoResult(test_d)
|
||||||
|
|
||||||
|
self.mock_resolver.resolve_service.assert_called_once_with(
|
||||||
|
b"_matrix._tcp.testserv"
|
||||||
|
)
|
||||||
|
|
||||||
|
# We should see an attempt to connect to the first server
|
||||||
|
clients = self.reactor.tcpClients
|
||||||
|
self.assertEqual(len(clients), 1)
|
||||||
|
(host, port, client_factory, _timeout, _bindAddress) = clients.pop(0)
|
||||||
|
self.assertEqual(host, "1.2.3.4")
|
||||||
|
self.assertEqual(port, 8443)
|
||||||
|
|
||||||
|
# Fail the connection
|
||||||
|
client_factory.clientConnectionFailed(None, Exception("nope"))
|
||||||
|
|
||||||
|
# There's a 300ms delay in HostnameEndpoint
|
||||||
|
self.reactor.pump((0.4,))
|
||||||
|
|
||||||
|
# Hasn't failed yet
|
||||||
|
self.assertNoResult(test_d)
|
||||||
|
|
||||||
|
# We should now see an attempt to connect to the second server
|
||||||
|
clients = self.reactor.tcpClients
|
||||||
|
self.assertEqual(len(clients), 1)
|
||||||
|
(host, port, client_factory, _timeout, _bindAddress) = clients.pop(0)
|
||||||
|
self.assertEqual(host, "1.2.3.4")
|
||||||
|
self.assertEqual(port, 8444)
|
||||||
|
|
||||||
|
# make a test server, and wire up the client
|
||||||
|
http_server = self._make_connection(client_factory, expected_sni=b"testserv")
|
||||||
|
|
||||||
|
self.assertEqual(len(http_server.requests), 1)
|
||||||
|
request = http_server.requests[0]
|
||||||
|
self.assertEqual(request.method, b"GET")
|
||||||
|
self.assertEqual(request.path, b"/foo/bar")
|
||||||
|
self.assertEqual(request.requestHeaders.getRawHeaders(b"host"), [b"testserv"])
|
||||||
|
|
||||||
|
# finish the request
|
||||||
|
request.finish()
|
||||||
|
self.reactor.pump((0.1,))
|
||||||
|
self.successResultOf(test_d)
|
||||||
|
|
||||||
|
|
||||||
|
class TestCachePeriodFromHeaders(unittest.TestCase):
|
||||||
def test_cache_control(self):
|
def test_cache_control(self):
|
||||||
# uppercase
|
# uppercase
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
|
|
|
@ -83,8 +83,10 @@ class SrvResolverTestCase(unittest.TestCase):
|
||||||
|
|
||||||
service_name = b"test_service.example.com"
|
service_name = b"test_service.example.com"
|
||||||
|
|
||||||
entry = Mock(spec_set=["expires"])
|
entry = Mock(spec_set=["expires", "priority", "weight"])
|
||||||
entry.expires = 0
|
entry.expires = 0
|
||||||
|
entry.priority = 0
|
||||||
|
entry.weight = 0
|
||||||
|
|
||||||
cache = {service_name: [entry]}
|
cache = {service_name: [entry]}
|
||||||
resolver = SrvResolver(dns_client=dns_client_mock, cache=cache)
|
resolver = SrvResolver(dns_client=dns_client_mock, cache=cache)
|
||||||
|
@ -105,8 +107,10 @@ class SrvResolverTestCase(unittest.TestCase):
|
||||||
|
|
||||||
service_name = b"test_service.example.com"
|
service_name = b"test_service.example.com"
|
||||||
|
|
||||||
entry = Mock(spec_set=["expires"])
|
entry = Mock(spec_set=["expires", "priority", "weight"])
|
||||||
entry.expires = 999999999
|
entry.expires = 999999999
|
||||||
|
entry.priority = 0
|
||||||
|
entry.weight = 0
|
||||||
|
|
||||||
cache = {service_name: [entry]}
|
cache = {service_name: [entry]}
|
||||||
resolver = SrvResolver(
|
resolver = SrvResolver(
|
||||||
|
|
|
@ -74,7 +74,6 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase):
|
||||||
self.assertEqual(events_to_filter[i].event_id, filtered[i].event_id)
|
self.assertEqual(events_to_filter[i].event_id, filtered[i].event_id)
|
||||||
self.assertEqual(filtered[i].content["a"], "b")
|
self.assertEqual(filtered[i].content["a"], "b")
|
||||||
|
|
||||||
@tests.unittest.DEBUG
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def test_erased_user(self):
|
def test_erased_user(self):
|
||||||
# 4 message events, from erased and unerased users, with a membership
|
# 4 message events, from erased and unerased users, with a membership
|
||||||
|
|
Loading…
Reference in New Issue