Use `ParamSpec` in a few places (#12667)
parent
c5969b346d
commit
fa0eab9c8e
|
@ -0,0 +1 @@
|
||||||
|
Use `ParamSpec` to refine type hints.
|
|
@ -1563,7 +1563,7 @@ url_preview = ["lxml"]
|
||||||
[metadata]
|
[metadata]
|
||||||
lock-version = "1.1"
|
lock-version = "1.1"
|
||||||
python-versions = "^3.7.1"
|
python-versions = "^3.7.1"
|
||||||
content-hash = "eebc9e1d720e2e866f5fddda98ce83d858949a6fdbe30c7e5aef4cf9d17be498"
|
content-hash = "d39d5ac5d51c014581186b7691999b861058b569084c525523baf70b77f292b1"
|
||||||
|
|
||||||
[metadata.files]
|
[metadata.files]
|
||||||
attrs = [
|
attrs = [
|
||||||
|
|
|
@ -143,7 +143,9 @@ netaddr = ">=0.7.18"
|
||||||
Jinja2 = ">=3.0"
|
Jinja2 = ">=3.0"
|
||||||
bleach = ">=1.4.3"
|
bleach = ">=1.4.3"
|
||||||
# We use `ParamSpec` and `Concatenate`, which were added in `typing-extensions` 3.10.0.0.
|
# We use `ParamSpec` and `Concatenate`, which were added in `typing-extensions` 3.10.0.0.
|
||||||
typing-extensions = ">=3.10.0"
|
# Additionally we need https://github.com/python/typing/pull/817 to allow types to be
|
||||||
|
# generic over ParamSpecs.
|
||||||
|
typing-extensions = ">=3.10.0.1"
|
||||||
# We enforce that we have a `cryptography` version that bundles an `openssl`
|
# We enforce that we have a `cryptography` version that bundles an `openssl`
|
||||||
# with the latest security patches.
|
# with the latest security patches.
|
||||||
cryptography = ">=3.4.7"
|
cryptography = ">=3.4.7"
|
||||||
|
|
|
@ -38,6 +38,7 @@ from typing import (
|
||||||
|
|
||||||
from cryptography.utils import CryptographyDeprecationWarning
|
from cryptography.utils import CryptographyDeprecationWarning
|
||||||
from matrix_common.versionstring import get_distribution_version_string
|
from matrix_common.versionstring import get_distribution_version_string
|
||||||
|
from typing_extensions import ParamSpec
|
||||||
|
|
||||||
import twisted
|
import twisted
|
||||||
from twisted.internet import defer, error, reactor as _reactor
|
from twisted.internet import defer, error, reactor as _reactor
|
||||||
|
@ -81,11 +82,12 @@ logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
# list of tuples of function, args list, kwargs dict
|
# list of tuples of function, args list, kwargs dict
|
||||||
_sighup_callbacks: List[
|
_sighup_callbacks: List[
|
||||||
Tuple[Callable[..., None], Tuple[Any, ...], Dict[str, Any]]
|
Tuple[Callable[..., None], Tuple[object, ...], Dict[str, object]]
|
||||||
] = []
|
] = []
|
||||||
|
P = ParamSpec("P")
|
||||||
|
|
||||||
|
|
||||||
def register_sighup(func: Callable[..., None], *args: Any, **kwargs: Any) -> None:
|
def register_sighup(func: Callable[P, None], *args: P.args, **kwargs: P.kwargs) -> None:
|
||||||
"""
|
"""
|
||||||
Register a function to be called when a SIGHUP occurs.
|
Register a function to be called when a SIGHUP occurs.
|
||||||
|
|
||||||
|
@ -93,7 +95,9 @@ def register_sighup(func: Callable[..., None], *args: Any, **kwargs: Any) -> Non
|
||||||
func: Function to be called when sent a SIGHUP signal.
|
func: Function to be called when sent a SIGHUP signal.
|
||||||
*args, **kwargs: args and kwargs to be passed to the target function.
|
*args, **kwargs: args and kwargs to be passed to the target function.
|
||||||
"""
|
"""
|
||||||
_sighup_callbacks.append((func, args, kwargs))
|
# This type-ignore should be redundant once we use a mypy release with
|
||||||
|
# https://github.com/python/mypy/pull/12668.
|
||||||
|
_sighup_callbacks.append((func, args, kwargs)) # type: ignore[arg-type]
|
||||||
|
|
||||||
|
|
||||||
def start_worker_reactor(
|
def start_worker_reactor(
|
||||||
|
@ -214,7 +218,9 @@ def redirect_stdio_to_logs() -> None:
|
||||||
print("Redirected stdout/stderr to logs")
|
print("Redirected stdout/stderr to logs")
|
||||||
|
|
||||||
|
|
||||||
def register_start(cb: Callable[..., Awaitable], *args: Any, **kwargs: Any) -> None:
|
def register_start(
|
||||||
|
cb: Callable[P, Awaitable], *args: P.args, **kwargs: P.kwargs
|
||||||
|
) -> None:
|
||||||
"""Register a callback with the reactor, to be called once it is running
|
"""Register a callback with the reactor, to be called once it is running
|
||||||
|
|
||||||
This can be used to initialise parts of the system which require an asynchronous
|
This can be used to initialise parts of the system which require an asynchronous
|
||||||
|
|
|
@ -22,9 +22,12 @@ from typing import (
|
||||||
List,
|
List,
|
||||||
Optional,
|
Optional,
|
||||||
Set,
|
Set,
|
||||||
|
TypeVar,
|
||||||
Union,
|
Union,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
from typing_extensions import ParamSpec
|
||||||
|
|
||||||
from synapse.api.presence import UserPresenceState
|
from synapse.api.presence import UserPresenceState
|
||||||
from synapse.util.async_helpers import maybe_awaitable
|
from synapse.util.async_helpers import maybe_awaitable
|
||||||
|
|
||||||
|
@ -40,6 +43,10 @@ GET_INTERESTED_USERS_CALLBACK = Callable[[str], Awaitable[Union[Set[str], str]]]
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
P = ParamSpec("P")
|
||||||
|
R = TypeVar("R")
|
||||||
|
|
||||||
|
|
||||||
def load_legacy_presence_router(hs: "HomeServer") -> None:
|
def load_legacy_presence_router(hs: "HomeServer") -> None:
|
||||||
"""Wrapper that loads a presence router module configured using the old
|
"""Wrapper that loads a presence router module configured using the old
|
||||||
configuration, and registers the hooks they implement.
|
configuration, and registers the hooks they implement.
|
||||||
|
@ -63,13 +70,15 @@ def load_legacy_presence_router(hs: "HomeServer") -> None:
|
||||||
|
|
||||||
# All methods that the module provides should be async, but this wasn't enforced
|
# All methods that the module provides should be async, but this wasn't enforced
|
||||||
# in the old module system, so we wrap them if needed
|
# in the old module system, so we wrap them if needed
|
||||||
def async_wrapper(f: Optional[Callable]) -> Optional[Callable[..., Awaitable]]:
|
def async_wrapper(
|
||||||
|
f: Optional[Callable[P, R]]
|
||||||
|
) -> Optional[Callable[P, Awaitable[R]]]:
|
||||||
# f might be None if the callback isn't implemented by the module. In this
|
# f might be None if the callback isn't implemented by the module. In this
|
||||||
# case we don't want to register a callback at all so we return None.
|
# case we don't want to register a callback at all so we return None.
|
||||||
if f is None:
|
if f is None:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def run(*args: Any, **kwargs: Any) -> Awaitable:
|
def run(*args: P.args, **kwargs: P.kwargs) -> Awaitable[R]:
|
||||||
# Assertion required because mypy can't prove we won't change `f`
|
# Assertion required because mypy can't prove we won't change `f`
|
||||||
# back to `None`. See
|
# back to `None`. See
|
||||||
# https://mypy.readthedocs.io/en/latest/common_issues.html#narrowing-and-inner-functions
|
# https://mypy.readthedocs.io/en/latest/common_issues.html#narrowing-and-inner-functions
|
||||||
|
@ -80,7 +89,7 @@ def load_legacy_presence_router(hs: "HomeServer") -> None:
|
||||||
return run
|
return run
|
||||||
|
|
||||||
# Register the hooks through the module API.
|
# Register the hooks through the module API.
|
||||||
hooks = {
|
hooks: Dict[str, Optional[Callable[..., Any]]] = {
|
||||||
hook: async_wrapper(getattr(presence_router, hook, None))
|
hook: async_wrapper(getattr(presence_router, hook, None))
|
||||||
for hook in presence_router_methods
|
for hook in presence_router_methods
|
||||||
}
|
}
|
||||||
|
|
|
@ -30,6 +30,7 @@ from typing import (
|
||||||
|
|
||||||
import attr
|
import attr
|
||||||
import jinja2
|
import jinja2
|
||||||
|
from typing_extensions import ParamSpec
|
||||||
|
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
from twisted.web.resource import Resource
|
from twisted.web.resource import Resource
|
||||||
|
@ -129,6 +130,7 @@ if TYPE_CHECKING:
|
||||||
|
|
||||||
|
|
||||||
T = TypeVar("T")
|
T = TypeVar("T")
|
||||||
|
P = ParamSpec("P")
|
||||||
|
|
||||||
"""
|
"""
|
||||||
This package defines the 'stable' API which can be used by extension modules which
|
This package defines the 'stable' API which can be used by extension modules which
|
||||||
|
@ -799,9 +801,9 @@ class ModuleApi:
|
||||||
def run_db_interaction(
|
def run_db_interaction(
|
||||||
self,
|
self,
|
||||||
desc: str,
|
desc: str,
|
||||||
func: Callable[..., T],
|
func: Callable[P, T],
|
||||||
*args: Any,
|
*args: P.args,
|
||||||
**kwargs: Any,
|
**kwargs: P.kwargs,
|
||||||
) -> "defer.Deferred[T]":
|
) -> "defer.Deferred[T]":
|
||||||
"""Run a function with a database connection
|
"""Run a function with a database connection
|
||||||
|
|
||||||
|
@ -817,8 +819,9 @@ class ModuleApi:
|
||||||
Returns:
|
Returns:
|
||||||
Deferred[object]: result of func
|
Deferred[object]: result of func
|
||||||
"""
|
"""
|
||||||
|
# type-ignore: See https://github.com/python/mypy/issues/8862
|
||||||
return defer.ensureDeferred(
|
return defer.ensureDeferred(
|
||||||
self._store.db_pool.runInteraction(desc, func, *args, **kwargs)
|
self._store.db_pool.runInteraction(desc, func, *args, **kwargs) # type: ignore[arg-type]
|
||||||
)
|
)
|
||||||
|
|
||||||
def complete_sso_login(
|
def complete_sso_login(
|
||||||
|
@ -1296,9 +1299,9 @@ class ModuleApi:
|
||||||
|
|
||||||
async def defer_to_thread(
|
async def defer_to_thread(
|
||||||
self,
|
self,
|
||||||
f: Callable[..., T],
|
f: Callable[P, T],
|
||||||
*args: Any,
|
*args: P.args,
|
||||||
**kwargs: Any,
|
**kwargs: P.kwargs,
|
||||||
) -> T:
|
) -> T:
|
||||||
"""Runs the given function in a separate thread from Synapse's thread pool.
|
"""Runs the given function in a separate thread from Synapse's thread pool.
|
||||||
|
|
||||||
|
|
|
@ -15,8 +15,6 @@
|
||||||
import logging
|
import logging
|
||||||
from typing import TYPE_CHECKING, Awaitable, Dict, List, Optional, Tuple
|
from typing import TYPE_CHECKING, Awaitable, Dict, List, Optional, Tuple
|
||||||
|
|
||||||
from twisted.web.server import Request
|
|
||||||
|
|
||||||
from synapse.api.constants import Membership
|
from synapse.api.constants import Membership
|
||||||
from synapse.api.errors import SynapseError
|
from synapse.api.errors import SynapseError
|
||||||
from synapse.http.server import HttpServer
|
from synapse.http.server import HttpServer
|
||||||
|
@ -97,7 +95,7 @@ class KnockRoomAliasServlet(RestServlet):
|
||||||
return 200, {"room_id": room_id}
|
return 200, {"room_id": room_id}
|
||||||
|
|
||||||
def on_PUT(
|
def on_PUT(
|
||||||
self, request: Request, room_identifier: str, txn_id: str
|
self, request: SynapseRequest, room_identifier: str, txn_id: str
|
||||||
) -> Awaitable[Tuple[int, JsonDict]]:
|
) -> Awaitable[Tuple[int, JsonDict]]:
|
||||||
set_tag("txn_id", txn_id)
|
set_tag("txn_id", txn_id)
|
||||||
|
|
||||||
|
|
|
@ -15,7 +15,9 @@
|
||||||
"""This module contains logic for storing HTTP PUT transactions. This is used
|
"""This module contains logic for storing HTTP PUT transactions. This is used
|
||||||
to ensure idempotency when performing PUTs using the REST API."""
|
to ensure idempotency when performing PUTs using the REST API."""
|
||||||
import logging
|
import logging
|
||||||
from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, Tuple
|
from typing import TYPE_CHECKING, Awaitable, Callable, Dict, Tuple
|
||||||
|
|
||||||
|
from typing_extensions import ParamSpec
|
||||||
|
|
||||||
from twisted.python.failure import Failure
|
from twisted.python.failure import Failure
|
||||||
from twisted.web.server import Request
|
from twisted.web.server import Request
|
||||||
|
@ -32,6 +34,9 @@ logger = logging.getLogger(__name__)
|
||||||
CLEANUP_PERIOD_MS = 1000 * 60 * 30 # 30 mins
|
CLEANUP_PERIOD_MS = 1000 * 60 * 30 # 30 mins
|
||||||
|
|
||||||
|
|
||||||
|
P = ParamSpec("P")
|
||||||
|
|
||||||
|
|
||||||
class HttpTransactionCache:
|
class HttpTransactionCache:
|
||||||
def __init__(self, hs: "HomeServer"):
|
def __init__(self, hs: "HomeServer"):
|
||||||
self.hs = hs
|
self.hs = hs
|
||||||
|
@ -65,9 +70,9 @@ class HttpTransactionCache:
|
||||||
def fetch_or_execute_request(
|
def fetch_or_execute_request(
|
||||||
self,
|
self,
|
||||||
request: Request,
|
request: Request,
|
||||||
fn: Callable[..., Awaitable[Tuple[int, JsonDict]]],
|
fn: Callable[P, Awaitable[Tuple[int, JsonDict]]],
|
||||||
*args: Any,
|
*args: P.args,
|
||||||
**kwargs: Any,
|
**kwargs: P.kwargs,
|
||||||
) -> Awaitable[Tuple[int, JsonDict]]:
|
) -> Awaitable[Tuple[int, JsonDict]]:
|
||||||
"""A helper function for fetch_or_execute which extracts
|
"""A helper function for fetch_or_execute which extracts
|
||||||
a transaction key from the given request.
|
a transaction key from the given request.
|
||||||
|
@ -82,9 +87,9 @@ class HttpTransactionCache:
|
||||||
def fetch_or_execute(
|
def fetch_or_execute(
|
||||||
self,
|
self,
|
||||||
txn_key: str,
|
txn_key: str,
|
||||||
fn: Callable[..., Awaitable[Tuple[int, JsonDict]]],
|
fn: Callable[P, Awaitable[Tuple[int, JsonDict]]],
|
||||||
*args: Any,
|
*args: P.args,
|
||||||
**kwargs: Any,
|
**kwargs: P.kwargs,
|
||||||
) -> Awaitable[Tuple[int, JsonDict]]:
|
) -> Awaitable[Tuple[int, JsonDict]]:
|
||||||
"""Fetches the response for this transaction, or executes the given function
|
"""Fetches the response for this transaction, or executes the given function
|
||||||
to produce a response for this transaction.
|
to produce a response for this transaction.
|
||||||
|
|
|
@ -192,7 +192,7 @@ class LoggingDatabaseConnection:
|
||||||
|
|
||||||
|
|
||||||
# The type of entry which goes on our after_callbacks and exception_callbacks lists.
|
# The type of entry which goes on our after_callbacks and exception_callbacks lists.
|
||||||
_CallbackListEntry = Tuple[Callable[..., object], Iterable[Any], Dict[str, Any]]
|
_CallbackListEntry = Tuple[Callable[..., object], Tuple[object, ...], Dict[str, object]]
|
||||||
|
|
||||||
P = ParamSpec("P")
|
P = ParamSpec("P")
|
||||||
R = TypeVar("R")
|
R = TypeVar("R")
|
||||||
|
@ -239,7 +239,9 @@ class LoggingTransaction:
|
||||||
self.after_callbacks = after_callbacks
|
self.after_callbacks = after_callbacks
|
||||||
self.exception_callbacks = exception_callbacks
|
self.exception_callbacks = exception_callbacks
|
||||||
|
|
||||||
def call_after(self, callback: Callable[..., object], *args: Any, **kwargs: Any):
|
def call_after(
|
||||||
|
self, callback: Callable[P, object], *args: P.args, **kwargs: P.kwargs
|
||||||
|
) -> None:
|
||||||
"""Call the given callback on the main twisted thread after the transaction has
|
"""Call the given callback on the main twisted thread after the transaction has
|
||||||
finished.
|
finished.
|
||||||
|
|
||||||
|
@ -256,11 +258,12 @@ class LoggingTransaction:
|
||||||
# LoggingTransaction isn't expecting there to be any callbacks; assert that
|
# LoggingTransaction isn't expecting there to be any callbacks; assert that
|
||||||
# is not the case.
|
# is not the case.
|
||||||
assert self.after_callbacks is not None
|
assert self.after_callbacks is not None
|
||||||
self.after_callbacks.append((callback, args, kwargs))
|
# type-ignore: need mypy containing https://github.com/python/mypy/pull/12668
|
||||||
|
self.after_callbacks.append((callback, args, kwargs)) # type: ignore[arg-type]
|
||||||
|
|
||||||
def call_on_exception(
|
def call_on_exception(
|
||||||
self, callback: Callable[..., object], *args: Any, **kwargs: Any
|
self, callback: Callable[P, object], *args: P.args, **kwargs: P.kwargs
|
||||||
):
|
) -> None:
|
||||||
"""Call the given callback on the main twisted thread after the transaction has
|
"""Call the given callback on the main twisted thread after the transaction has
|
||||||
failed.
|
failed.
|
||||||
|
|
||||||
|
@ -274,7 +277,8 @@ class LoggingTransaction:
|
||||||
# LoggingTransaction isn't expecting there to be any callbacks; assert that
|
# LoggingTransaction isn't expecting there to be any callbacks; assert that
|
||||||
# is not the case.
|
# is not the case.
|
||||||
assert self.exception_callbacks is not None
|
assert self.exception_callbacks is not None
|
||||||
self.exception_callbacks.append((callback, args, kwargs))
|
# type-ignore: need mypy containing https://github.com/python/mypy/pull/12668
|
||||||
|
self.exception_callbacks.append((callback, args, kwargs)) # type: ignore[arg-type]
|
||||||
|
|
||||||
def fetchone(self) -> Optional[Tuple]:
|
def fetchone(self) -> Optional[Tuple]:
|
||||||
return self.txn.fetchone()
|
return self.txn.fetchone()
|
||||||
|
@ -549,9 +553,9 @@ class DatabasePool:
|
||||||
desc: str,
|
desc: str,
|
||||||
after_callbacks: List[_CallbackListEntry],
|
after_callbacks: List[_CallbackListEntry],
|
||||||
exception_callbacks: List[_CallbackListEntry],
|
exception_callbacks: List[_CallbackListEntry],
|
||||||
func: Callable[..., R],
|
func: Callable[Concatenate[LoggingTransaction, P], R],
|
||||||
*args: Any,
|
*args: P.args,
|
||||||
**kwargs: Any,
|
**kwargs: P.kwargs,
|
||||||
) -> R:
|
) -> R:
|
||||||
"""Start a new database transaction with the given connection.
|
"""Start a new database transaction with the given connection.
|
||||||
|
|
||||||
|
@ -581,7 +585,10 @@ class DatabasePool:
|
||||||
# will fail if we have to repeat the transaction.
|
# will fail if we have to repeat the transaction.
|
||||||
# For now, we just log an error, and hope that it works on the first attempt.
|
# For now, we just log an error, and hope that it works on the first attempt.
|
||||||
# TODO: raise an exception.
|
# TODO: raise an exception.
|
||||||
for i, arg in enumerate(args):
|
|
||||||
|
# Type-ignore Mypy doesn't yet consider ParamSpec.args to be iterable; see
|
||||||
|
# https://github.com/python/mypy/pull/12668
|
||||||
|
for i, arg in enumerate(args): # type: ignore[arg-type, var-annotated]
|
||||||
if inspect.isgenerator(arg):
|
if inspect.isgenerator(arg):
|
||||||
logger.error(
|
logger.error(
|
||||||
"Programming error: generator passed to new_transaction as "
|
"Programming error: generator passed to new_transaction as "
|
||||||
|
@ -589,7 +596,9 @@ class DatabasePool:
|
||||||
i,
|
i,
|
||||||
func,
|
func,
|
||||||
)
|
)
|
||||||
for name, val in kwargs.items():
|
# Type-ignore Mypy doesn't yet consider ParamSpec.args to be a mapping; see
|
||||||
|
# https://github.com/python/mypy/pull/12668
|
||||||
|
for name, val in kwargs.items(): # type: ignore[attr-defined]
|
||||||
if inspect.isgenerator(val):
|
if inspect.isgenerator(val):
|
||||||
logger.error(
|
logger.error(
|
||||||
"Programming error: generator passed to new_transaction as "
|
"Programming error: generator passed to new_transaction as "
|
||||||
|
|
|
@ -1648,8 +1648,12 @@ class PersistEventsStore:
|
||||||
txn.call_after(prefill)
|
txn.call_after(prefill)
|
||||||
|
|
||||||
def _store_redaction(self, txn: LoggingTransaction, event: EventBase) -> None:
|
def _store_redaction(self, txn: LoggingTransaction, event: EventBase) -> None:
|
||||||
# Invalidate the caches for the redacted event, note that these caches
|
"""Invalidate the caches for the redacted event.
|
||||||
# are also cleared as part of event replication in _invalidate_caches_for_event.
|
|
||||||
|
Note that these caches are also cleared as part of event replication in
|
||||||
|
_invalidate_caches_for_event.
|
||||||
|
"""
|
||||||
|
assert event.redacts is not None
|
||||||
txn.call_after(self.store._invalidate_get_event_cache, event.redacts)
|
txn.call_after(self.store._invalidate_get_event_cache, event.redacts)
|
||||||
txn.call_after(self.store.get_relations_for_event.invalidate, (event.redacts,))
|
txn.call_after(self.store.get_relations_for_event.invalidate, (event.redacts,))
|
||||||
txn.call_after(self.store.get_applicable_edit.invalidate, (event.redacts,))
|
txn.call_after(self.store.get_applicable_edit.invalidate, (event.redacts,))
|
||||||
|
|
|
@ -42,7 +42,7 @@ from typing import (
|
||||||
)
|
)
|
||||||
|
|
||||||
import attr
|
import attr
|
||||||
from typing_extensions import AsyncContextManager, Literal
|
from typing_extensions import AsyncContextManager, Concatenate, Literal, ParamSpec
|
||||||
|
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
from twisted.internet.defer import CancelledError
|
from twisted.internet.defer import CancelledError
|
||||||
|
@ -237,9 +237,16 @@ async def concurrently_execute(
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
P = ParamSpec("P")
|
||||||
|
R = TypeVar("R")
|
||||||
|
|
||||||
|
|
||||||
async def yieldable_gather_results(
|
async def yieldable_gather_results(
|
||||||
func: Callable[..., Awaitable[T]], iter: Iterable, *args: Any, **kwargs: Any
|
func: Callable[Concatenate[T, P], Awaitable[R]],
|
||||||
) -> List[T]:
|
iter: Iterable[T],
|
||||||
|
*args: P.args,
|
||||||
|
**kwargs: P.kwargs,
|
||||||
|
) -> List[R]:
|
||||||
"""Executes the function with each argument concurrently.
|
"""Executes the function with each argument concurrently.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@ -255,7 +262,15 @@ async def yieldable_gather_results(
|
||||||
try:
|
try:
|
||||||
return await make_deferred_yieldable(
|
return await make_deferred_yieldable(
|
||||||
defer.gatherResults(
|
defer.gatherResults(
|
||||||
[run_in_background(func, item, *args, **kwargs) for item in iter],
|
# type-ignore: mypy reports two errors:
|
||||||
|
# error: Argument 1 to "run_in_background" has incompatible type
|
||||||
|
# "Callable[[T, **P], Awaitable[R]]"; expected
|
||||||
|
# "Callable[[T, **P], Awaitable[R]]" [arg-type]
|
||||||
|
# error: Argument 2 to "run_in_background" has incompatible type
|
||||||
|
# "T"; expected "[T, **P.args]" [arg-type]
|
||||||
|
# The former looks like a mypy bug, and the latter looks like a
|
||||||
|
# false positive.
|
||||||
|
[run_in_background(func, item, *args, **kwargs) for item in iter], # type: ignore[arg-type]
|
||||||
consumeErrors=True,
|
consumeErrors=True,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
@ -577,9 +592,6 @@ class ReadWriteLock:
|
||||||
return _ctx_manager()
|
return _ctx_manager()
|
||||||
|
|
||||||
|
|
||||||
R = TypeVar("R")
|
|
||||||
|
|
||||||
|
|
||||||
def timeout_deferred(
|
def timeout_deferred(
|
||||||
deferred: "defer.Deferred[_T]", timeout: float, reactor: IReactorTime
|
deferred: "defer.Deferred[_T]", timeout: float, reactor: IReactorTime
|
||||||
) -> "defer.Deferred[_T]":
|
) -> "defer.Deferred[_T]":
|
||||||
|
|
|
@ -12,7 +12,19 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
import logging
|
import logging
|
||||||
from typing import Any, Callable, Dict, List
|
from typing import (
|
||||||
|
Any,
|
||||||
|
Awaitable,
|
||||||
|
Callable,
|
||||||
|
Dict,
|
||||||
|
Generic,
|
||||||
|
List,
|
||||||
|
Optional,
|
||||||
|
TypeVar,
|
||||||
|
Union,
|
||||||
|
)
|
||||||
|
|
||||||
|
from typing_extensions import ParamSpec
|
||||||
|
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
|
|
||||||
|
@ -75,7 +87,11 @@ class Distributor:
|
||||||
run_as_background_process(name, self.signals[name].fire, *args, **kwargs)
|
run_as_background_process(name, self.signals[name].fire, *args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
class Signal:
|
P = ParamSpec("P")
|
||||||
|
R = TypeVar("R")
|
||||||
|
|
||||||
|
|
||||||
|
class Signal(Generic[P]):
|
||||||
"""A Signal is a dispatch point that stores a list of callables as
|
"""A Signal is a dispatch point that stores a list of callables as
|
||||||
observers of it.
|
observers of it.
|
||||||
|
|
||||||
|
@ -87,16 +103,16 @@ class Signal:
|
||||||
|
|
||||||
def __init__(self, name: str):
|
def __init__(self, name: str):
|
||||||
self.name: str = name
|
self.name: str = name
|
||||||
self.observers: List[Callable] = []
|
self.observers: List[Callable[P, Any]] = []
|
||||||
|
|
||||||
def observe(self, observer: Callable) -> None:
|
def observe(self, observer: Callable[P, Any]) -> None:
|
||||||
"""Adds a new callable to the observer list which will be invoked by
|
"""Adds a new callable to the observer list which will be invoked by
|
||||||
the 'fire' method.
|
the 'fire' method.
|
||||||
|
|
||||||
Each observer callable may return a Deferred."""
|
Each observer callable may return a Deferred."""
|
||||||
self.observers.append(observer)
|
self.observers.append(observer)
|
||||||
|
|
||||||
def fire(self, *args: Any, **kwargs: Any) -> "defer.Deferred[List[Any]]":
|
def fire(self, *args: P.args, **kwargs: P.kwargs) -> "defer.Deferred[List[Any]]":
|
||||||
"""Invokes every callable in the observer list, passing in the args and
|
"""Invokes every callable in the observer list, passing in the args and
|
||||||
kwargs. Exceptions thrown by observers are logged but ignored. It is
|
kwargs. Exceptions thrown by observers are logged but ignored. It is
|
||||||
not an error to fire a signal with no observers.
|
not an error to fire a signal with no observers.
|
||||||
|
@ -104,7 +120,7 @@ class Signal:
|
||||||
Returns a Deferred that will complete when all the observers have
|
Returns a Deferred that will complete when all the observers have
|
||||||
completed."""
|
completed."""
|
||||||
|
|
||||||
async def do(observer: Callable[..., Any]) -> Any:
|
async def do(observer: Callable[P, Union[R, Awaitable[R]]]) -> Optional[R]:
|
||||||
try:
|
try:
|
||||||
return await maybe_awaitable(observer(*args, **kwargs))
|
return await maybe_awaitable(observer(*args, **kwargs))
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
@ -114,6 +130,7 @@ class Signal:
|
||||||
observer,
|
observer,
|
||||||
e,
|
e,
|
||||||
)
|
)
|
||||||
|
return None
|
||||||
|
|
||||||
deferreds = [run_in_background(do, o) for o in self.observers]
|
deferreds = [run_in_background(do, o) for o in self.observers]
|
||||||
|
|
||||||
|
|
|
@ -15,10 +15,10 @@
|
||||||
import logging
|
import logging
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
from types import TracebackType
|
from types import TracebackType
|
||||||
from typing import Any, Callable, Optional, Type, TypeVar, cast
|
from typing import Awaitable, Callable, Optional, Type, TypeVar
|
||||||
|
|
||||||
from prometheus_client import Counter
|
from prometheus_client import Counter
|
||||||
from typing_extensions import Protocol
|
from typing_extensions import Concatenate, ParamSpec, Protocol
|
||||||
|
|
||||||
from synapse.logging.context import (
|
from synapse.logging.context import (
|
||||||
ContextResourceUsage,
|
ContextResourceUsage,
|
||||||
|
@ -72,16 +72,21 @@ in_flight: InFlightGauge[_InFlightMetric] = InFlightGauge(
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
T = TypeVar("T", bound=Callable[..., Any])
|
P = ParamSpec("P")
|
||||||
|
R = TypeVar("R")
|
||||||
|
|
||||||
|
|
||||||
class HasClock(Protocol):
|
class HasClock(Protocol):
|
||||||
clock: Clock
|
clock: Clock
|
||||||
|
|
||||||
|
|
||||||
def measure_func(name: Optional[str] = None) -> Callable[[T], T]:
|
def measure_func(
|
||||||
"""
|
name: Optional[str] = None,
|
||||||
Used to decorate an async function with a `Measure` context manager.
|
) -> Callable[[Callable[P, Awaitable[R]]], Callable[P, Awaitable[R]]]:
|
||||||
|
"""Decorate an async method with a `Measure` context manager.
|
||||||
|
|
||||||
|
The Measure is created using `self.clock`; it should only be used to decorate
|
||||||
|
methods in classes defining an instance-level `clock` attribute.
|
||||||
|
|
||||||
Usage:
|
Usage:
|
||||||
|
|
||||||
|
@ -97,18 +102,24 @@ def measure_func(name: Optional[str] = None) -> Callable[[T], T]:
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def wrapper(func: T) -> T:
|
def wrapper(
|
||||||
|
func: Callable[Concatenate[HasClock, P], Awaitable[R]]
|
||||||
|
) -> Callable[P, Awaitable[R]]:
|
||||||
block_name = func.__name__ if name is None else name
|
block_name = func.__name__ if name is None else name
|
||||||
|
|
||||||
@wraps(func)
|
@wraps(func)
|
||||||
async def measured_func(self: HasClock, *args: Any, **kwargs: Any) -> Any:
|
async def measured_func(self: HasClock, *args: P.args, **kwargs: P.kwargs) -> R:
|
||||||
with Measure(self.clock, block_name):
|
with Measure(self.clock, block_name):
|
||||||
r = await func(self, *args, **kwargs)
|
r = await func(self, *args, **kwargs)
|
||||||
return r
|
return r
|
||||||
|
|
||||||
return cast(T, measured_func)
|
# There are some shenanigans here, because we're decorating a method but
|
||||||
|
# explicitly making use of the `self` parameter. The key thing here is that the
|
||||||
|
# return type within the return type for `measure_func` itself describes how the
|
||||||
|
# decorated function will be called.
|
||||||
|
return measured_func # type: ignore[return-value]
|
||||||
|
|
||||||
return wrapper
|
return wrapper # type: ignore[return-value]
|
||||||
|
|
||||||
|
|
||||||
class Measure:
|
class Measure:
|
||||||
|
|
|
@ -16,6 +16,8 @@ import functools
|
||||||
import sys
|
import sys
|
||||||
from typing import Any, Callable, Generator, List, TypeVar, cast
|
from typing import Any, Callable, Generator, List, TypeVar, cast
|
||||||
|
|
||||||
|
from typing_extensions import ParamSpec
|
||||||
|
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
from twisted.internet.defer import Deferred
|
from twisted.internet.defer import Deferred
|
||||||
from twisted.python.failure import Failure
|
from twisted.python.failure import Failure
|
||||||
|
@ -25,6 +27,7 @@ _already_patched = False
|
||||||
|
|
||||||
|
|
||||||
T = TypeVar("T")
|
T = TypeVar("T")
|
||||||
|
P = ParamSpec("P")
|
||||||
|
|
||||||
|
|
||||||
def do_patch() -> None:
|
def do_patch() -> None:
|
||||||
|
@ -41,13 +44,13 @@ def do_patch() -> None:
|
||||||
return
|
return
|
||||||
|
|
||||||
def new_inline_callbacks(
|
def new_inline_callbacks(
|
||||||
f: Callable[..., Generator["Deferred[object]", object, T]]
|
f: Callable[P, Generator["Deferred[object]", object, T]]
|
||||||
) -> Callable[..., "Deferred[T]"]:
|
) -> Callable[P, "Deferred[T]"]:
|
||||||
@functools.wraps(f)
|
@functools.wraps(f)
|
||||||
def wrapped(*args: Any, **kwargs: Any) -> "Deferred[T]":
|
def wrapped(*args: P.args, **kwargs: P.kwargs) -> "Deferred[T]":
|
||||||
start_context = current_context()
|
start_context = current_context()
|
||||||
changes: List[str] = []
|
changes: List[str] = []
|
||||||
orig: Callable[..., "Deferred[T]"] = orig_inline_callbacks(
|
orig: Callable[P, "Deferred[T]"] = orig_inline_callbacks(
|
||||||
_check_yield_points(f, changes)
|
_check_yield_points(f, changes)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -115,7 +118,7 @@ def do_patch() -> None:
|
||||||
|
|
||||||
|
|
||||||
def _check_yield_points(
|
def _check_yield_points(
|
||||||
f: Callable[..., Generator["Deferred[object]", object, T]],
|
f: Callable[P, Generator["Deferred[object]", object, T]],
|
||||||
changes: List[str],
|
changes: List[str],
|
||||||
) -> Callable:
|
) -> Callable:
|
||||||
"""Wraps a generator that is about to be passed to defer.inlineCallbacks
|
"""Wraps a generator that is about to be passed to defer.inlineCallbacks
|
||||||
|
@ -138,7 +141,7 @@ def _check_yield_points(
|
||||||
|
|
||||||
@functools.wraps(f)
|
@functools.wraps(f)
|
||||||
def check_yield_points_inner(
|
def check_yield_points_inner(
|
||||||
*args: Any, **kwargs: Any
|
*args: P.args, **kwargs: P.kwargs
|
||||||
) -> Generator["Deferred[object]", object, T]:
|
) -> Generator["Deferred[object]", object, T]:
|
||||||
gen = f(*args, **kwargs)
|
gen = f(*args, **kwargs)
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue