Use `ParamSpec` in a few places (#12667)

pull/12679/head
David Robertson 2022-05-09 11:27:39 +01:00 committed by GitHub
parent c5969b346d
commit fa0eab9c8e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
14 changed files with 148 additions and 68 deletions

1
changelog.d/12667.misc Normal file
View File

@ -0,0 +1 @@
Use `ParamSpec` to refine type hints.

2
poetry.lock generated
View File

@ -1563,7 +1563,7 @@ url_preview = ["lxml"]
[metadata] [metadata]
lock-version = "1.1" lock-version = "1.1"
python-versions = "^3.7.1" python-versions = "^3.7.1"
content-hash = "eebc9e1d720e2e866f5fddda98ce83d858949a6fdbe30c7e5aef4cf9d17be498" content-hash = "d39d5ac5d51c014581186b7691999b861058b569084c525523baf70b77f292b1"
[metadata.files] [metadata.files]
attrs = [ attrs = [

View File

@ -143,7 +143,9 @@ netaddr = ">=0.7.18"
Jinja2 = ">=3.0" Jinja2 = ">=3.0"
bleach = ">=1.4.3" bleach = ">=1.4.3"
# We use `ParamSpec` and `Concatenate`, which were added in `typing-extensions` 3.10.0.0. # We use `ParamSpec` and `Concatenate`, which were added in `typing-extensions` 3.10.0.0.
typing-extensions = ">=3.10.0" # Additionally we need https://github.com/python/typing/pull/817 to allow types to be
# generic over ParamSpecs.
typing-extensions = ">=3.10.0.1"
# We enforce that we have a `cryptography` version that bundles an `openssl` # We enforce that we have a `cryptography` version that bundles an `openssl`
# with the latest security patches. # with the latest security patches.
cryptography = ">=3.4.7" cryptography = ">=3.4.7"

View File

@ -38,6 +38,7 @@ from typing import (
from cryptography.utils import CryptographyDeprecationWarning from cryptography.utils import CryptographyDeprecationWarning
from matrix_common.versionstring import get_distribution_version_string from matrix_common.versionstring import get_distribution_version_string
from typing_extensions import ParamSpec
import twisted import twisted
from twisted.internet import defer, error, reactor as _reactor from twisted.internet import defer, error, reactor as _reactor
@ -81,11 +82,12 @@ logger = logging.getLogger(__name__)
# list of tuples of function, args list, kwargs dict # list of tuples of function, args list, kwargs dict
_sighup_callbacks: List[ _sighup_callbacks: List[
Tuple[Callable[..., None], Tuple[Any, ...], Dict[str, Any]] Tuple[Callable[..., None], Tuple[object, ...], Dict[str, object]]
] = [] ] = []
P = ParamSpec("P")
def register_sighup(func: Callable[..., None], *args: Any, **kwargs: Any) -> None: def register_sighup(func: Callable[P, None], *args: P.args, **kwargs: P.kwargs) -> None:
""" """
Register a function to be called when a SIGHUP occurs. Register a function to be called when a SIGHUP occurs.
@ -93,7 +95,9 @@ def register_sighup(func: Callable[..., None], *args: Any, **kwargs: Any) -> Non
func: Function to be called when sent a SIGHUP signal. func: Function to be called when sent a SIGHUP signal.
*args, **kwargs: args and kwargs to be passed to the target function. *args, **kwargs: args and kwargs to be passed to the target function.
""" """
_sighup_callbacks.append((func, args, kwargs)) # This type-ignore should be redundant once we use a mypy release with
# https://github.com/python/mypy/pull/12668.
_sighup_callbacks.append((func, args, kwargs)) # type: ignore[arg-type]
def start_worker_reactor( def start_worker_reactor(
@ -214,7 +218,9 @@ def redirect_stdio_to_logs() -> None:
print("Redirected stdout/stderr to logs") print("Redirected stdout/stderr to logs")
def register_start(cb: Callable[..., Awaitable], *args: Any, **kwargs: Any) -> None: def register_start(
cb: Callable[P, Awaitable], *args: P.args, **kwargs: P.kwargs
) -> None:
"""Register a callback with the reactor, to be called once it is running """Register a callback with the reactor, to be called once it is running
This can be used to initialise parts of the system which require an asynchronous This can be used to initialise parts of the system which require an asynchronous

View File

@ -22,9 +22,12 @@ from typing import (
List, List,
Optional, Optional,
Set, Set,
TypeVar,
Union, Union,
) )
from typing_extensions import ParamSpec
from synapse.api.presence import UserPresenceState from synapse.api.presence import UserPresenceState
from synapse.util.async_helpers import maybe_awaitable from synapse.util.async_helpers import maybe_awaitable
@ -40,6 +43,10 @@ GET_INTERESTED_USERS_CALLBACK = Callable[[str], Awaitable[Union[Set[str], str]]]
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
P = ParamSpec("P")
R = TypeVar("R")
def load_legacy_presence_router(hs: "HomeServer") -> None: def load_legacy_presence_router(hs: "HomeServer") -> None:
"""Wrapper that loads a presence router module configured using the old """Wrapper that loads a presence router module configured using the old
configuration, and registers the hooks they implement. configuration, and registers the hooks they implement.
@ -63,13 +70,15 @@ def load_legacy_presence_router(hs: "HomeServer") -> None:
# All methods that the module provides should be async, but this wasn't enforced # All methods that the module provides should be async, but this wasn't enforced
# in the old module system, so we wrap them if needed # in the old module system, so we wrap them if needed
def async_wrapper(f: Optional[Callable]) -> Optional[Callable[..., Awaitable]]: def async_wrapper(
f: Optional[Callable[P, R]]
) -> Optional[Callable[P, Awaitable[R]]]:
# f might be None if the callback isn't implemented by the module. In this # f might be None if the callback isn't implemented by the module. In this
# case we don't want to register a callback at all so we return None. # case we don't want to register a callback at all so we return None.
if f is None: if f is None:
return None return None
def run(*args: Any, **kwargs: Any) -> Awaitable: def run(*args: P.args, **kwargs: P.kwargs) -> Awaitable[R]:
# Assertion required because mypy can't prove we won't change `f` # Assertion required because mypy can't prove we won't change `f`
# back to `None`. See # back to `None`. See
# https://mypy.readthedocs.io/en/latest/common_issues.html#narrowing-and-inner-functions # https://mypy.readthedocs.io/en/latest/common_issues.html#narrowing-and-inner-functions
@ -80,7 +89,7 @@ def load_legacy_presence_router(hs: "HomeServer") -> None:
return run return run
# Register the hooks through the module API. # Register the hooks through the module API.
hooks = { hooks: Dict[str, Optional[Callable[..., Any]]] = {
hook: async_wrapper(getattr(presence_router, hook, None)) hook: async_wrapper(getattr(presence_router, hook, None))
for hook in presence_router_methods for hook in presence_router_methods
} }

View File

@ -30,6 +30,7 @@ from typing import (
import attr import attr
import jinja2 import jinja2
from typing_extensions import ParamSpec
from twisted.internet import defer from twisted.internet import defer
from twisted.web.resource import Resource from twisted.web.resource import Resource
@ -129,6 +130,7 @@ if TYPE_CHECKING:
T = TypeVar("T") T = TypeVar("T")
P = ParamSpec("P")
""" """
This package defines the 'stable' API which can be used by extension modules which This package defines the 'stable' API which can be used by extension modules which
@ -799,9 +801,9 @@ class ModuleApi:
def run_db_interaction( def run_db_interaction(
self, self,
desc: str, desc: str,
func: Callable[..., T], func: Callable[P, T],
*args: Any, *args: P.args,
**kwargs: Any, **kwargs: P.kwargs,
) -> "defer.Deferred[T]": ) -> "defer.Deferred[T]":
"""Run a function with a database connection """Run a function with a database connection
@ -817,8 +819,9 @@ class ModuleApi:
Returns: Returns:
Deferred[object]: result of func Deferred[object]: result of func
""" """
# type-ignore: See https://github.com/python/mypy/issues/8862
return defer.ensureDeferred( return defer.ensureDeferred(
self._store.db_pool.runInteraction(desc, func, *args, **kwargs) self._store.db_pool.runInteraction(desc, func, *args, **kwargs) # type: ignore[arg-type]
) )
def complete_sso_login( def complete_sso_login(
@ -1296,9 +1299,9 @@ class ModuleApi:
async def defer_to_thread( async def defer_to_thread(
self, self,
f: Callable[..., T], f: Callable[P, T],
*args: Any, *args: P.args,
**kwargs: Any, **kwargs: P.kwargs,
) -> T: ) -> T:
"""Runs the given function in a separate thread from Synapse's thread pool. """Runs the given function in a separate thread from Synapse's thread pool.

View File

@ -15,8 +15,6 @@
import logging import logging
from typing import TYPE_CHECKING, Awaitable, Dict, List, Optional, Tuple from typing import TYPE_CHECKING, Awaitable, Dict, List, Optional, Tuple
from twisted.web.server import Request
from synapse.api.constants import Membership from synapse.api.constants import Membership
from synapse.api.errors import SynapseError from synapse.api.errors import SynapseError
from synapse.http.server import HttpServer from synapse.http.server import HttpServer
@ -97,7 +95,7 @@ class KnockRoomAliasServlet(RestServlet):
return 200, {"room_id": room_id} return 200, {"room_id": room_id}
def on_PUT( def on_PUT(
self, request: Request, room_identifier: str, txn_id: str self, request: SynapseRequest, room_identifier: str, txn_id: str
) -> Awaitable[Tuple[int, JsonDict]]: ) -> Awaitable[Tuple[int, JsonDict]]:
set_tag("txn_id", txn_id) set_tag("txn_id", txn_id)

View File

@ -15,7 +15,9 @@
"""This module contains logic for storing HTTP PUT transactions. This is used """This module contains logic for storing HTTP PUT transactions. This is used
to ensure idempotency when performing PUTs using the REST API.""" to ensure idempotency when performing PUTs using the REST API."""
import logging import logging
from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, Tuple from typing import TYPE_CHECKING, Awaitable, Callable, Dict, Tuple
from typing_extensions import ParamSpec
from twisted.python.failure import Failure from twisted.python.failure import Failure
from twisted.web.server import Request from twisted.web.server import Request
@ -32,6 +34,9 @@ logger = logging.getLogger(__name__)
CLEANUP_PERIOD_MS = 1000 * 60 * 30 # 30 mins CLEANUP_PERIOD_MS = 1000 * 60 * 30 # 30 mins
P = ParamSpec("P")
class HttpTransactionCache: class HttpTransactionCache:
def __init__(self, hs: "HomeServer"): def __init__(self, hs: "HomeServer"):
self.hs = hs self.hs = hs
@ -65,9 +70,9 @@ class HttpTransactionCache:
def fetch_or_execute_request( def fetch_or_execute_request(
self, self,
request: Request, request: Request,
fn: Callable[..., Awaitable[Tuple[int, JsonDict]]], fn: Callable[P, Awaitable[Tuple[int, JsonDict]]],
*args: Any, *args: P.args,
**kwargs: Any, **kwargs: P.kwargs,
) -> Awaitable[Tuple[int, JsonDict]]: ) -> Awaitable[Tuple[int, JsonDict]]:
"""A helper function for fetch_or_execute which extracts """A helper function for fetch_or_execute which extracts
a transaction key from the given request. a transaction key from the given request.
@ -82,9 +87,9 @@ class HttpTransactionCache:
def fetch_or_execute( def fetch_or_execute(
self, self,
txn_key: str, txn_key: str,
fn: Callable[..., Awaitable[Tuple[int, JsonDict]]], fn: Callable[P, Awaitable[Tuple[int, JsonDict]]],
*args: Any, *args: P.args,
**kwargs: Any, **kwargs: P.kwargs,
) -> Awaitable[Tuple[int, JsonDict]]: ) -> Awaitable[Tuple[int, JsonDict]]:
"""Fetches the response for this transaction, or executes the given function """Fetches the response for this transaction, or executes the given function
to produce a response for this transaction. to produce a response for this transaction.

View File

@ -192,7 +192,7 @@ class LoggingDatabaseConnection:
# The type of entry which goes on our after_callbacks and exception_callbacks lists. # The type of entry which goes on our after_callbacks and exception_callbacks lists.
_CallbackListEntry = Tuple[Callable[..., object], Iterable[Any], Dict[str, Any]] _CallbackListEntry = Tuple[Callable[..., object], Tuple[object, ...], Dict[str, object]]
P = ParamSpec("P") P = ParamSpec("P")
R = TypeVar("R") R = TypeVar("R")
@ -239,7 +239,9 @@ class LoggingTransaction:
self.after_callbacks = after_callbacks self.after_callbacks = after_callbacks
self.exception_callbacks = exception_callbacks self.exception_callbacks = exception_callbacks
def call_after(self, callback: Callable[..., object], *args: Any, **kwargs: Any): def call_after(
self, callback: Callable[P, object], *args: P.args, **kwargs: P.kwargs
) -> None:
"""Call the given callback on the main twisted thread after the transaction has """Call the given callback on the main twisted thread after the transaction has
finished. finished.
@ -256,11 +258,12 @@ class LoggingTransaction:
# LoggingTransaction isn't expecting there to be any callbacks; assert that # LoggingTransaction isn't expecting there to be any callbacks; assert that
# is not the case. # is not the case.
assert self.after_callbacks is not None assert self.after_callbacks is not None
self.after_callbacks.append((callback, args, kwargs)) # type-ignore: need mypy containing https://github.com/python/mypy/pull/12668
self.after_callbacks.append((callback, args, kwargs)) # type: ignore[arg-type]
def call_on_exception( def call_on_exception(
self, callback: Callable[..., object], *args: Any, **kwargs: Any self, callback: Callable[P, object], *args: P.args, **kwargs: P.kwargs
): ) -> None:
"""Call the given callback on the main twisted thread after the transaction has """Call the given callback on the main twisted thread after the transaction has
failed. failed.
@ -274,7 +277,8 @@ class LoggingTransaction:
# LoggingTransaction isn't expecting there to be any callbacks; assert that # LoggingTransaction isn't expecting there to be any callbacks; assert that
# is not the case. # is not the case.
assert self.exception_callbacks is not None assert self.exception_callbacks is not None
self.exception_callbacks.append((callback, args, kwargs)) # type-ignore: need mypy containing https://github.com/python/mypy/pull/12668
self.exception_callbacks.append((callback, args, kwargs)) # type: ignore[arg-type]
def fetchone(self) -> Optional[Tuple]: def fetchone(self) -> Optional[Tuple]:
return self.txn.fetchone() return self.txn.fetchone()
@ -549,9 +553,9 @@ class DatabasePool:
desc: str, desc: str,
after_callbacks: List[_CallbackListEntry], after_callbacks: List[_CallbackListEntry],
exception_callbacks: List[_CallbackListEntry], exception_callbacks: List[_CallbackListEntry],
func: Callable[..., R], func: Callable[Concatenate[LoggingTransaction, P], R],
*args: Any, *args: P.args,
**kwargs: Any, **kwargs: P.kwargs,
) -> R: ) -> R:
"""Start a new database transaction with the given connection. """Start a new database transaction with the given connection.
@ -581,7 +585,10 @@ class DatabasePool:
# will fail if we have to repeat the transaction. # will fail if we have to repeat the transaction.
# For now, we just log an error, and hope that it works on the first attempt. # For now, we just log an error, and hope that it works on the first attempt.
# TODO: raise an exception. # TODO: raise an exception.
for i, arg in enumerate(args):
# Type-ignore Mypy doesn't yet consider ParamSpec.args to be iterable; see
# https://github.com/python/mypy/pull/12668
for i, arg in enumerate(args): # type: ignore[arg-type, var-annotated]
if inspect.isgenerator(arg): if inspect.isgenerator(arg):
logger.error( logger.error(
"Programming error: generator passed to new_transaction as " "Programming error: generator passed to new_transaction as "
@ -589,7 +596,9 @@ class DatabasePool:
i, i,
func, func,
) )
for name, val in kwargs.items(): # Type-ignore Mypy doesn't yet consider ParamSpec.args to be a mapping; see
# https://github.com/python/mypy/pull/12668
for name, val in kwargs.items(): # type: ignore[attr-defined]
if inspect.isgenerator(val): if inspect.isgenerator(val):
logger.error( logger.error(
"Programming error: generator passed to new_transaction as " "Programming error: generator passed to new_transaction as "

View File

@ -1648,8 +1648,12 @@ class PersistEventsStore:
txn.call_after(prefill) txn.call_after(prefill)
def _store_redaction(self, txn: LoggingTransaction, event: EventBase) -> None: def _store_redaction(self, txn: LoggingTransaction, event: EventBase) -> None:
# Invalidate the caches for the redacted event, note that these caches """Invalidate the caches for the redacted event.
# are also cleared as part of event replication in _invalidate_caches_for_event.
Note that these caches are also cleared as part of event replication in
_invalidate_caches_for_event.
"""
assert event.redacts is not None
txn.call_after(self.store._invalidate_get_event_cache, event.redacts) txn.call_after(self.store._invalidate_get_event_cache, event.redacts)
txn.call_after(self.store.get_relations_for_event.invalidate, (event.redacts,)) txn.call_after(self.store.get_relations_for_event.invalidate, (event.redacts,))
txn.call_after(self.store.get_applicable_edit.invalidate, (event.redacts,)) txn.call_after(self.store.get_applicable_edit.invalidate, (event.redacts,))

View File

@ -42,7 +42,7 @@ from typing import (
) )
import attr import attr
from typing_extensions import AsyncContextManager, Literal from typing_extensions import AsyncContextManager, Concatenate, Literal, ParamSpec
from twisted.internet import defer from twisted.internet import defer
from twisted.internet.defer import CancelledError from twisted.internet.defer import CancelledError
@ -237,9 +237,16 @@ async def concurrently_execute(
) )
P = ParamSpec("P")
R = TypeVar("R")
async def yieldable_gather_results( async def yieldable_gather_results(
func: Callable[..., Awaitable[T]], iter: Iterable, *args: Any, **kwargs: Any func: Callable[Concatenate[T, P], Awaitable[R]],
) -> List[T]: iter: Iterable[T],
*args: P.args,
**kwargs: P.kwargs,
) -> List[R]:
"""Executes the function with each argument concurrently. """Executes the function with each argument concurrently.
Args: Args:
@ -255,7 +262,15 @@ async def yieldable_gather_results(
try: try:
return await make_deferred_yieldable( return await make_deferred_yieldable(
defer.gatherResults( defer.gatherResults(
[run_in_background(func, item, *args, **kwargs) for item in iter], # type-ignore: mypy reports two errors:
# error: Argument 1 to "run_in_background" has incompatible type
# "Callable[[T, **P], Awaitable[R]]"; expected
# "Callable[[T, **P], Awaitable[R]]" [arg-type]
# error: Argument 2 to "run_in_background" has incompatible type
# "T"; expected "[T, **P.args]" [arg-type]
# The former looks like a mypy bug, and the latter looks like a
# false positive.
[run_in_background(func, item, *args, **kwargs) for item in iter], # type: ignore[arg-type]
consumeErrors=True, consumeErrors=True,
) )
) )
@ -577,9 +592,6 @@ class ReadWriteLock:
return _ctx_manager() return _ctx_manager()
R = TypeVar("R")
def timeout_deferred( def timeout_deferred(
deferred: "defer.Deferred[_T]", timeout: float, reactor: IReactorTime deferred: "defer.Deferred[_T]", timeout: float, reactor: IReactorTime
) -> "defer.Deferred[_T]": ) -> "defer.Deferred[_T]":

View File

@ -12,7 +12,19 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import logging import logging
from typing import Any, Callable, Dict, List from typing import (
Any,
Awaitable,
Callable,
Dict,
Generic,
List,
Optional,
TypeVar,
Union,
)
from typing_extensions import ParamSpec
from twisted.internet import defer from twisted.internet import defer
@ -75,7 +87,11 @@ class Distributor:
run_as_background_process(name, self.signals[name].fire, *args, **kwargs) run_as_background_process(name, self.signals[name].fire, *args, **kwargs)
class Signal: P = ParamSpec("P")
R = TypeVar("R")
class Signal(Generic[P]):
"""A Signal is a dispatch point that stores a list of callables as """A Signal is a dispatch point that stores a list of callables as
observers of it. observers of it.
@ -87,16 +103,16 @@ class Signal:
def __init__(self, name: str): def __init__(self, name: str):
self.name: str = name self.name: str = name
self.observers: List[Callable] = [] self.observers: List[Callable[P, Any]] = []
def observe(self, observer: Callable) -> None: def observe(self, observer: Callable[P, Any]) -> None:
"""Adds a new callable to the observer list which will be invoked by """Adds a new callable to the observer list which will be invoked by
the 'fire' method. the 'fire' method.
Each observer callable may return a Deferred.""" Each observer callable may return a Deferred."""
self.observers.append(observer) self.observers.append(observer)
def fire(self, *args: Any, **kwargs: Any) -> "defer.Deferred[List[Any]]": def fire(self, *args: P.args, **kwargs: P.kwargs) -> "defer.Deferred[List[Any]]":
"""Invokes every callable in the observer list, passing in the args and """Invokes every callable in the observer list, passing in the args and
kwargs. Exceptions thrown by observers are logged but ignored. It is kwargs. Exceptions thrown by observers are logged but ignored. It is
not an error to fire a signal with no observers. not an error to fire a signal with no observers.
@ -104,7 +120,7 @@ class Signal:
Returns a Deferred that will complete when all the observers have Returns a Deferred that will complete when all the observers have
completed.""" completed."""
async def do(observer: Callable[..., Any]) -> Any: async def do(observer: Callable[P, Union[R, Awaitable[R]]]) -> Optional[R]:
try: try:
return await maybe_awaitable(observer(*args, **kwargs)) return await maybe_awaitable(observer(*args, **kwargs))
except Exception as e: except Exception as e:
@ -114,6 +130,7 @@ class Signal:
observer, observer,
e, e,
) )
return None
deferreds = [run_in_background(do, o) for o in self.observers] deferreds = [run_in_background(do, o) for o in self.observers]

View File

@ -15,10 +15,10 @@
import logging import logging
from functools import wraps from functools import wraps
from types import TracebackType from types import TracebackType
from typing import Any, Callable, Optional, Type, TypeVar, cast from typing import Awaitable, Callable, Optional, Type, TypeVar
from prometheus_client import Counter from prometheus_client import Counter
from typing_extensions import Protocol from typing_extensions import Concatenate, ParamSpec, Protocol
from synapse.logging.context import ( from synapse.logging.context import (
ContextResourceUsage, ContextResourceUsage,
@ -72,16 +72,21 @@ in_flight: InFlightGauge[_InFlightMetric] = InFlightGauge(
) )
T = TypeVar("T", bound=Callable[..., Any]) P = ParamSpec("P")
R = TypeVar("R")
class HasClock(Protocol): class HasClock(Protocol):
clock: Clock clock: Clock
def measure_func(name: Optional[str] = None) -> Callable[[T], T]: def measure_func(
""" name: Optional[str] = None,
Used to decorate an async function with a `Measure` context manager. ) -> Callable[[Callable[P, Awaitable[R]]], Callable[P, Awaitable[R]]]:
"""Decorate an async method with a `Measure` context manager.
The Measure is created using `self.clock`; it should only be used to decorate
methods in classes defining an instance-level `clock` attribute.
Usage: Usage:
@ -97,18 +102,24 @@ def measure_func(name: Optional[str] = None) -> Callable[[T], T]:
""" """
def wrapper(func: T) -> T: def wrapper(
func: Callable[Concatenate[HasClock, P], Awaitable[R]]
) -> Callable[P, Awaitable[R]]:
block_name = func.__name__ if name is None else name block_name = func.__name__ if name is None else name
@wraps(func) @wraps(func)
async def measured_func(self: HasClock, *args: Any, **kwargs: Any) -> Any: async def measured_func(self: HasClock, *args: P.args, **kwargs: P.kwargs) -> R:
with Measure(self.clock, block_name): with Measure(self.clock, block_name):
r = await func(self, *args, **kwargs) r = await func(self, *args, **kwargs)
return r return r
return cast(T, measured_func) # There are some shenanigans here, because we're decorating a method but
# explicitly making use of the `self` parameter. The key thing here is that the
# return type within the return type for `measure_func` itself describes how the
# decorated function will be called.
return measured_func # type: ignore[return-value]
return wrapper return wrapper # type: ignore[return-value]
class Measure: class Measure:

View File

@ -16,6 +16,8 @@ import functools
import sys import sys
from typing import Any, Callable, Generator, List, TypeVar, cast from typing import Any, Callable, Generator, List, TypeVar, cast
from typing_extensions import ParamSpec
from twisted.internet import defer from twisted.internet import defer
from twisted.internet.defer import Deferred from twisted.internet.defer import Deferred
from twisted.python.failure import Failure from twisted.python.failure import Failure
@ -25,6 +27,7 @@ _already_patched = False
T = TypeVar("T") T = TypeVar("T")
P = ParamSpec("P")
def do_patch() -> None: def do_patch() -> None:
@ -41,13 +44,13 @@ def do_patch() -> None:
return return
def new_inline_callbacks( def new_inline_callbacks(
f: Callable[..., Generator["Deferred[object]", object, T]] f: Callable[P, Generator["Deferred[object]", object, T]]
) -> Callable[..., "Deferred[T]"]: ) -> Callable[P, "Deferred[T]"]:
@functools.wraps(f) @functools.wraps(f)
def wrapped(*args: Any, **kwargs: Any) -> "Deferred[T]": def wrapped(*args: P.args, **kwargs: P.kwargs) -> "Deferred[T]":
start_context = current_context() start_context = current_context()
changes: List[str] = [] changes: List[str] = []
orig: Callable[..., "Deferred[T]"] = orig_inline_callbacks( orig: Callable[P, "Deferred[T]"] = orig_inline_callbacks(
_check_yield_points(f, changes) _check_yield_points(f, changes)
) )
@ -115,7 +118,7 @@ def do_patch() -> None:
def _check_yield_points( def _check_yield_points(
f: Callable[..., Generator["Deferred[object]", object, T]], f: Callable[P, Generator["Deferred[object]", object, T]],
changes: List[str], changes: List[str],
) -> Callable: ) -> Callable:
"""Wraps a generator that is about to be passed to defer.inlineCallbacks """Wraps a generator that is about to be passed to defer.inlineCallbacks
@ -138,7 +141,7 @@ def _check_yield_points(
@functools.wraps(f) @functools.wraps(f)
def check_yield_points_inner( def check_yield_points_inner(
*args: Any, **kwargs: Any *args: P.args, **kwargs: P.kwargs
) -> Generator["Deferred[object]", object, T]: ) -> Generator["Deferred[object]", object, T]:
gen = f(*args, **kwargs) gen = f(*args, **kwargs)