Type hints for the remaining two files in `synapse.http`. (#11164)

* Teach MyPy that the sentinel context is False

This means that if `ctx: LoggingContextOrSentinel`
then `bool(ctx)` narrows us to `ctx:LoggingContext`, which is a really
neat find!

* Annotate RequestMetrics

- Raise errors for sentry if we use the sentinel context
- Ensure we don't raise an error and carry on, but not recording stats
- Include stack trace in the error case to lower Sean's blood pressure

* Make mypy pass for synapse.http.request_metrics

* Make synapse.http.connectproxyclient pass mypy

Co-authored-by: reivilibre <oliverw@matrix.org>
pull/11206/head
David Robertson 2021-10-28 14:14:42 +01:00 committed by GitHub
parent a19bf32a03
commit 1bfd141205
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 56 additions and 37 deletions

1
changelog.d/11164.misc Normal file
View File

@ -0,0 +1 @@
Add type hints so that `synapse.http` passes `mypy` checks.

View File

@ -16,6 +16,7 @@ no_implicit_optional = True
files =
scripts-dev/sign_json,
synapse/__init__.py,
synapse/api,
synapse/appservice,
synapse/config,
@ -31,16 +32,7 @@ files =
synapse/federation,
synapse/groups,
synapse/handlers,
synapse/http/additional_resource.py,
synapse/http/client.py,
synapse/http/federation/matrix_federation_agent.py,
synapse/http/federation/srv_resolver.py,
synapse/http/federation/well_known_resolver.py,
synapse/http/matrixfederationclient.py,
synapse/http/proxyagent.py,
synapse/http/servlet.py,
synapse/http/server.py,
synapse/http/site.py,
synapse/http,
synapse/logging,
synapse/metrics,
synapse/module_api,

View File

@ -84,7 +84,11 @@ class HTTPConnectProxyEndpoint:
def __repr__(self):
return "<HTTPConnectProxyEndpoint %s>" % (self._proxy_endpoint,)
def connect(self, protocolFactory: ClientFactory):
# Mypy encounters a false positive here: it complains that ClientFactory
# is incompatible with IProtocolFactory. But ClientFactory inherits from
# Factory, which implements IProtocolFactory. So I think this is a bug
# in mypy-zope.
def connect(self, protocolFactory: ClientFactory): # type: ignore[override]
f = HTTPProxiedClientFactory(
self._host, self._port, protocolFactory, self._proxy_creds
)
@ -119,13 +123,15 @@ class HTTPProxiedClientFactory(protocol.ClientFactory):
self.dst_port = dst_port
self.wrapped_factory = wrapped_factory
self.proxy_creds = proxy_creds
self.on_connection = defer.Deferred()
self.on_connection: "defer.Deferred[None]" = defer.Deferred()
def startedConnecting(self, connector):
return self.wrapped_factory.startedConnecting(connector)
def buildProtocol(self, addr):
wrapped_protocol = self.wrapped_factory.buildProtocol(addr)
if wrapped_protocol is None:
raise TypeError("buildProtocol produced None instead of a Protocol")
return HTTPConnectProtocol(
self.dst_host,
@ -235,7 +241,7 @@ class HTTPConnectSetupClient(http.HTTPClient):
self.host = host
self.port = port
self.proxy_creds = proxy_creds
self.on_connected = defer.Deferred()
self.on_connected: "defer.Deferred[None]" = defer.Deferred()
def connectionMade(self):
logger.debug("Connected to proxy, sending CONNECT")

View File

@ -15,6 +15,8 @@
import logging
import threading
import traceback
from typing import Dict, Mapping, Set, Tuple
from prometheus_client.core import Counter, Histogram
@ -105,19 +107,14 @@ in_flight_requests_db_sched_duration = Counter(
["method", "servlet"],
)
# The set of all in flight requests, set[RequestMetrics]
_in_flight_requests = set()
_in_flight_requests: Set["RequestMetrics"] = set()
# Protects the _in_flight_requests set from concurrent access
_in_flight_requests_lock = threading.Lock()
def _get_in_flight_counts():
"""Returns a count of all in flight requests by (method, server_name)
Returns:
dict[tuple[str, str], int]
"""
def _get_in_flight_counts() -> Mapping[Tuple[str, ...], int]:
"""Returns a count of all in flight requests by (method, server_name)"""
# Cast to a list to prevent it changing while the Prometheus
# thread is collecting metrics
with _in_flight_requests_lock:
@ -127,8 +124,9 @@ def _get_in_flight_counts():
rm.update_metrics()
# Map from (method, name) -> int, the number of in flight requests of that
# type
counts = {}
# type. The key type is Tuple[str, str], but we leave the length unspecified
# for compatability with LaterGauge's annotations.
counts: Dict[Tuple[str, ...], int] = {}
for rm in reqs:
key = (rm.method, rm.name)
counts[key] = counts.get(key, 0) + 1
@ -145,15 +143,21 @@ LaterGauge(
class RequestMetrics:
def start(self, time_sec, name, method):
self.start = time_sec
def start(self, time_sec: float, name: str, method: str) -> None:
self.start_ts = time_sec
self.start_context = current_context()
self.name = name
self.method = method
# _request_stats records resource usage that we have already added
# to the "in flight" metrics.
self._request_stats = self.start_context.get_resource_usage()
if self.start_context:
# _request_stats records resource usage that we have already added
# to the "in flight" metrics.
self._request_stats = self.start_context.get_resource_usage()
else:
logger.error(
"Tried to start a RequestMetric from the sentinel context.\n%s",
"".join(traceback.format_stack()),
)
with _in_flight_requests_lock:
_in_flight_requests.add(self)
@ -169,12 +173,18 @@ class RequestMetrics:
tag = context.tag
if context != self.start_context:
logger.warning(
logger.error(
"Context have unexpectedly changed %r, %r",
context,
self.start_context,
)
return
else:
logger.error(
"Trying to stop RequestMetrics in the sentinel context.\n%s",
"".join(traceback.format_stack()),
)
return
response_code = str(response_code)
@ -183,7 +193,7 @@ class RequestMetrics:
response_count.labels(self.method, self.name, tag).inc()
response_timer.labels(self.method, self.name, tag, response_code).observe(
time_sec - self.start
time_sec - self.start_ts
)
resource_usage = context.get_resource_usage()
@ -213,6 +223,12 @@ class RequestMetrics:
def update_metrics(self):
"""Updates the in flight metrics with values from this request."""
if not self.start_context:
logger.error(
"Tried to update a RequestMetric from the sentinel context.\n%s",
"".join(traceback.format_stack()),
)
return
new_stats = self.start_context.get_resource_usage()
diff = new_stats - self._request_stats

View File

@ -220,7 +220,7 @@ class _Sentinel:
self.scope = None
self.tag = None
def __str__(self):
def __str__(self) -> str:
return "sentinel"
def copy_to(self, record):
@ -241,7 +241,7 @@ class _Sentinel:
def record_event_fetch(self, event_count):
pass
def __bool__(self):
def __bool__(self) -> Literal[False]:
return False

View File

@ -20,7 +20,7 @@ import os
import platform
import threading
import time
from typing import Callable, Dict, Iterable, Optional, Tuple, Union
from typing import Callable, Dict, Iterable, Mapping, Optional, Tuple, Union
import attr
from prometheus_client import Counter, Gauge, Histogram
@ -67,7 +67,11 @@ class LaterGauge:
labels = attr.ib(hash=False, type=Optional[Iterable[str]])
# callback: should either return a value (if there are no labels for this metric),
# or dict mapping from a label tuple to a value
caller = attr.ib(type=Callable[[], Union[Dict[Tuple[str, ...], float], float]])
caller = attr.ib(
type=Callable[
[], Union[Mapping[Tuple[str, ...], Union[int, float]], Union[int, float]]
]
)
def collect(self):
@ -80,11 +84,11 @@ class LaterGauge:
yield g
return
if isinstance(calls, dict):
if isinstance(calls, (int, float)):
g.add_metric([], calls)
else:
for k, v in calls.items():
g.add_metric(k, v)
else:
g.add_metric([], calls)
yield g