Additional type hints for the proxy agent and SRV resolver modules. (#10608)

pull/10613/head
Dirk Klimpel 2021-08-18 19:53:20 +02:00 committed by GitHub
parent 78a70a2e0b
commit 0c3565da4c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 41 additions and 25 deletions

1
changelog.d/10608.misc Normal file
View File

@ -0,0 +1 @@
Improve type hints for the proxy agent and SRV resolver modules. Contributed by @dklimpel.

View File

@ -28,10 +28,13 @@ files =
synapse/federation, synapse/federation,
synapse/groups, synapse/groups,
synapse/handlers, synapse/handlers,
synapse/http/additional_resource.py,
synapse/http/client.py, synapse/http/client.py,
synapse/http/federation/matrix_federation_agent.py, synapse/http/federation/matrix_federation_agent.py,
synapse/http/federation/srv_resolver.py,
synapse/http/federation/well_known_resolver.py, synapse/http/federation/well_known_resolver.py,
synapse/http/matrixfederationclient.py, synapse/http/matrixfederationclient.py,
synapse/http/proxyagent.py,
synapse/http/servlet.py, synapse/http/servlet.py,
synapse/http/server.py, synapse/http/server.py,
synapse/http/site.py, synapse/http/site.py,

View File

@ -12,8 +12,15 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from typing import TYPE_CHECKING
from twisted.web.server import Request
from synapse.http.server import DirectServeJsonResource from synapse.http.server import DirectServeJsonResource
if TYPE_CHECKING:
from synapse.server import HomeServer
class AdditionalResource(DirectServeJsonResource): class AdditionalResource(DirectServeJsonResource):
"""Resource wrapper for additional_resources """Resource wrapper for additional_resources
@ -25,7 +32,7 @@ class AdditionalResource(DirectServeJsonResource):
and exception handling. and exception handling.
""" """
def __init__(self, hs, handler): def __init__(self, hs: "HomeServer", handler):
"""Initialise AdditionalResource """Initialise AdditionalResource
The ``handler`` should return a deferred which completes when it has The ``handler`` should return a deferred which completes when it has
@ -33,14 +40,14 @@ class AdditionalResource(DirectServeJsonResource):
``request.write()``, and call ``request.finish()``. ``request.write()``, and call ``request.finish()``.
Args: Args:
hs (synapse.server.HomeServer): homeserver hs: homeserver
handler ((twisted.web.server.Request) -> twisted.internet.defer.Deferred): handler ((twisted.web.server.Request) -> twisted.internet.defer.Deferred):
function to be called to handle the request. function to be called to handle the request.
""" """
super().__init__() super().__init__()
self._handler = handler self._handler = handler
def _async_render(self, request): def _async_render(self, request: Request):
# Cheekily pass the result straight through, so we don't need to worry # Cheekily pass the result straight through, so we don't need to worry
# if its an awaitable or not. # if its an awaitable or not.
return self._handler(request) return self._handler(request)

View File

@ -16,7 +16,7 @@
import logging import logging
import random import random
import time import time
from typing import List from typing import Callable, Dict, List
import attr import attr
@ -28,35 +28,35 @@ from synapse.logging.context import make_deferred_yieldable
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
SERVER_CACHE = {} SERVER_CACHE: Dict[bytes, List["Server"]] = {}
@attr.s(slots=True, frozen=True) @attr.s(auto_attribs=True, slots=True, frozen=True)
class Server: class Server:
""" """
Our record of an individual server which can be tried to reach a destination. Our record of an individual server which can be tried to reach a destination.
Attributes: Attributes:
host (bytes): target hostname host: target hostname
port (int): port:
priority (int): priority:
weight (int): weight:
expires (int): when the cache should expire this record - in *seconds* since expires: when the cache should expire this record - in *seconds* since
the epoch the epoch
""" """
host = attr.ib() host: bytes
port = attr.ib() port: int
priority = attr.ib(default=0) priority: int = 0
weight = attr.ib(default=0) weight: int = 0
expires = attr.ib(default=0) expires: int = 0
def _sort_server_list(server_list): def _sort_server_list(server_list: List[Server]) -> List[Server]:
"""Given a list of SRV records sort them into priority order and shuffle """Given a list of SRV records sort them into priority order and shuffle
each priority with the given weight. each priority with the given weight.
""" """
priority_map = {} priority_map: Dict[int, List[Server]] = {}
for server in server_list: for server in server_list:
priority_map.setdefault(server.priority, []).append(server) priority_map.setdefault(server.priority, []).append(server)
@ -103,11 +103,16 @@ class SrvResolver:
Args: Args:
dns_client (twisted.internet.interfaces.IResolver): twisted resolver impl dns_client (twisted.internet.interfaces.IResolver): twisted resolver impl
cache (dict): cache object cache: cache object
get_time (callable): clock implementation. Should return seconds since the epoch get_time: clock implementation. Should return seconds since the epoch
""" """
def __init__(self, dns_client=client, cache=SERVER_CACHE, get_time=time.time): def __init__(
self,
dns_client=client,
cache: Dict[bytes, List[Server]] = SERVER_CACHE,
get_time: Callable[[], float] = time.time,
):
self._dns_client = dns_client self._dns_client = dns_client
self._cache = cache self._cache = cache
self._get_time = get_time self._get_time = get_time
@ -116,7 +121,7 @@ class SrvResolver:
"""Look up a SRV record """Look up a SRV record
Args: Args:
service_name (bytes): record to look up service_name: record to look up
Returns: Returns:
a list of the SRV records, or an empty list if none found a list of the SRV records, or an empty list if none found
@ -158,7 +163,7 @@ class SrvResolver:
and answers[0].payload and answers[0].payload
and answers[0].payload.target == dns.Name(b".") and answers[0].payload.target == dns.Name(b".")
): ):
raise ConnectError("Service %s unavailable" % service_name) raise ConnectError(f"Service {service_name!r} unavailable")
servers = [] servers = []

View File

@ -173,7 +173,7 @@ class ProxyAgent(_AgentBase):
raise ValueError(f"Invalid URI {uri!r}") raise ValueError(f"Invalid URI {uri!r}")
parsed_uri = URI.fromBytes(uri) parsed_uri = URI.fromBytes(uri)
pool_key = (parsed_uri.scheme, parsed_uri.host, parsed_uri.port) pool_key = f"{parsed_uri.scheme!r}{parsed_uri.host!r}{parsed_uri.port}"
request_path = parsed_uri.originForm request_path = parsed_uri.originForm
should_skip_proxy = False should_skip_proxy = False
@ -199,7 +199,7 @@ class ProxyAgent(_AgentBase):
) )
# Cache *all* connections under the same key, since we are only # Cache *all* connections under the same key, since we are only
# connecting to a single destination, the proxy: # connecting to a single destination, the proxy:
pool_key = ("http-proxy", self.http_proxy_endpoint) pool_key = "http-proxy"
endpoint = self.http_proxy_endpoint endpoint = self.http_proxy_endpoint
request_path = uri request_path = uri
elif ( elif (