diff --git a/synapse/http/endpoint.py b/synapse/http/endpoint.py index 4775f6707d..e80d00e2af 100644 --- a/synapse/http/endpoint.py +++ b/synapse/http/endpoint.py @@ -22,6 +22,7 @@ from twisted.names.error import DNSNameError, DomainError import collections import logging import random +import time logger = logging.getLogger(__name__) @@ -31,7 +32,7 @@ SERVER_CACHE = {} _Server = collections.namedtuple( - "_Server", "priority weight host port" + "_Server", "priority weight host port expires" ) @@ -92,7 +93,8 @@ class SRVClientEndpoint(object): host=domain, port=default_port, priority=0, - weight=0 + weight=0, + expires=0, ) else: self.default_server = None @@ -154,6 +156,12 @@ class SRVClientEndpoint(object): @defer.inlineCallbacks def resolve_service(service_name, dns_client=client, cache=SERVER_CACHE): + cache_entry = cache.get(service_name, None) + if cache_entry: + if all(s.expires > int(time.time()) for s in cache_entry): + servers = list(cache_entry) + defer.returnValue(servers) + servers = [] try: @@ -173,27 +181,26 @@ def resolve_service(service_name, dns_client=client, cache=SERVER_CACHE): continue payload = answer.payload - host = str(payload.target) + srv_ttl = answer.ttl try: answers, _, _ = yield dns_client.lookupAddress(host) except DNSNameError: continue - ips = [ - answer.payload.dottedQuad() - for answer in answers - if answer.type == dns.A and answer.payload - ] + for answer in answers: + if answer.type == dns.A and answer.payload: + ip = answer.payload.dottedQuad() + host_ttl = min(srv_ttl, answer.ttl) - for ip in ips: - servers.append(_Server( - host=ip, - port=int(payload.port), - priority=int(payload.priority), - weight=int(payload.weight) - )) + servers.append(_Server( + host=ip, + port=int(payload.port), + priority=int(payload.priority), + weight=int(payload.weight), + expires=int(time.time()) + host_ttl, + )) servers.sort() cache[service_name] = list(servers) diff --git a/tests/test_dns.py b/tests/test_dns.py index 637b1606f8..e006ed1a59 100644 --- a/tests/test_dns.py +++ b/tests/test_dns.py @@ -69,8 +69,11 @@ class DnsTestCase(unittest.TestCase): service_name = "test_service.examle.com" + entry = Mock(spec_set=["expires"]) + entry.expires = 999999999 + cache = { - service_name: [object()] + service_name: [entry] } servers = yield resolve_service(