Improve type hints for cached decorator. (#15658)

The cached decorators always return a Deferred, which was not
properly propagated. It was close enough when wrapping coroutines,
but failed if a bare function was wrapped.
pull/15520/head
Patrick Cloke 2023-05-24 08:59:31 -04:00 committed by GitHub
parent 379eb2d7ab
commit 1f55c04cbc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 73 additions and 63 deletions

1
changelog.d/15658.misc Normal file
View File

@ -0,0 +1 @@
Improve type hints.

View File

@ -18,10 +18,11 @@ can crop up, e.g the cache descriptors.
from typing import Callable, Optional, Type from typing import Callable, Optional, Type
from mypy.erasetype import remove_instance_last_known_values
from mypy.nodes import ARG_NAMED_OPT from mypy.nodes import ARG_NAMED_OPT
from mypy.plugin import MethodSigContext, Plugin from mypy.plugin import MethodSigContext, Plugin
from mypy.typeops import bind_self from mypy.typeops import bind_self
from mypy.types import CallableType, NoneType, UnionType from mypy.types import CallableType, Instance, NoneType, UnionType
class SynapsePlugin(Plugin): class SynapsePlugin(Plugin):
@ -92,10 +93,41 @@ def cached_function_method_signature(ctx: MethodSigContext) -> CallableType:
arg_names.append("on_invalidate") arg_names.append("on_invalidate")
arg_kinds.append(ARG_NAMED_OPT) # Arg is an optional kwarg. arg_kinds.append(ARG_NAMED_OPT) # Arg is an optional kwarg.
# Finally we ensure the return type is a Deferred.
if (
isinstance(signature.ret_type, Instance)
and signature.ret_type.type.fullname == "twisted.internet.defer.Deferred"
):
# If it is already a Deferred, nothing to do.
ret_type = signature.ret_type
else:
ret_arg = None
if isinstance(signature.ret_type, Instance):
# If a coroutine, wrap the coroutine's return type in a Deferred.
if signature.ret_type.type.fullname == "typing.Coroutine":
ret_arg = signature.ret_type.args[2]
# If an awaitable, wrap the awaitable's final value in a Deferred.
elif signature.ret_type.type.fullname == "typing.Awaitable":
ret_arg = signature.ret_type.args[0]
# Otherwise, wrap the return value in a Deferred.
if ret_arg is None:
ret_arg = signature.ret_type
# This should be able to use ctx.api.named_generic_type, but that doesn't seem
# to find the correct symbol for anything more than 1 module deep.
#
# modules is not part of CheckerPluginInterface. The following is a combination
# of TypeChecker.named_generic_type and TypeChecker.lookup_typeinfo.
sym = ctx.api.modules["twisted.internet.defer"].names.get("Deferred") # type: ignore[attr-defined]
ret_type = Instance(sym.node, [remove_instance_last_known_values(ret_arg)])
signature = signature.copy_modified( signature = signature.copy_modified(
arg_types=arg_types, arg_types=arg_types,
arg_names=arg_names, arg_names=arg_names,
arg_kinds=arg_kinds, arg_kinds=arg_kinds,
ret_type=ret_type,
) )
return signature return signature

View File

@ -1099,7 +1099,7 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore):
# `get_joined_hosts` is called with the "current" state group for the # `get_joined_hosts` is called with the "current" state group for the
# room, and so consecutive calls will be for consecutive state groups # room, and so consecutive calls will be for consecutive state groups
# which point to the previous state group. # which point to the previous state group.
cache = await self._get_joined_hosts_cache(room_id) # type: ignore[misc] cache = await self._get_joined_hosts_cache(room_id)
# If the state group in the cache matches, we already have the data we need. # If the state group in the cache matches, we already have the data we need.
if state_entry.state_group == cache.state_group: if state_entry.state_group == cache.state_group:

View File

@ -220,7 +220,9 @@ class DeferredCacheDescriptor(_CacheDescriptorBase):
self.iterable = iterable self.iterable = iterable
self.prune_unread_entries = prune_unread_entries self.prune_unread_entries = prune_unread_entries
def __get__(self, obj: Optional[Any], owner: Optional[Type]) -> Callable[..., Any]: def __get__(
self, obj: Optional[Any], owner: Optional[Type]
) -> Callable[..., "defer.Deferred[Any]"]:
cache: DeferredCache[CacheKey, Any] = DeferredCache( cache: DeferredCache[CacheKey, Any] = DeferredCache(
name=self.name, name=self.name,
max_entries=self.max_entries, max_entries=self.max_entries,
@ -232,7 +234,7 @@ class DeferredCacheDescriptor(_CacheDescriptorBase):
get_cache_key = self.cache_key_builder get_cache_key = self.cache_key_builder
@functools.wraps(self.orig) @functools.wraps(self.orig)
def _wrapped(*args: Any, **kwargs: Any) -> Any: def _wrapped(*args: Any, **kwargs: Any) -> "defer.Deferred[Any]":
# If we're passed a cache_context then we'll want to call its invalidate() # If we're passed a cache_context then we'll want to call its invalidate()
# whenever we are invalidated # whenever we are invalidated
invalidate_callback = kwargs.pop("on_invalidate", None) invalidate_callback = kwargs.pop("on_invalidate", None)

View File

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import re import re
from typing import Generator from typing import Any, Generator
from unittest.mock import Mock from unittest.mock import Mock
from twisted.internet import defer from twisted.internet import defer
@ -49,93 +49,81 @@ class ApplicationServiceTestCase(unittest.TestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def test_regex_user_id_prefix_match( def test_regex_user_id_prefix_match(
self, self,
) -> Generator["defer.Deferred[object]", object, None]: ) -> Generator["defer.Deferred[Any]", object, None]:
self.service.namespaces[ApplicationService.NS_USERS].append(_regex("@irc_.*")) self.service.namespaces[ApplicationService.NS_USERS].append(_regex("@irc_.*"))
self.event.sender = "@irc_foobar:matrix.org" self.event.sender = "@irc_foobar:matrix.org"
self.assertTrue( self.assertTrue(
( (
yield defer.ensureDeferred( yield self.service.is_interested_in_event(
self.service.is_interested_in_event(
self.event.event_id, self.event, self.store self.event.event_id, self.event, self.store
) )
) )
) )
)
@defer.inlineCallbacks @defer.inlineCallbacks
def test_regex_user_id_prefix_no_match( def test_regex_user_id_prefix_no_match(
self, self,
) -> Generator["defer.Deferred[object]", object, None]: ) -> Generator["defer.Deferred[Any]", object, None]:
self.service.namespaces[ApplicationService.NS_USERS].append(_regex("@irc_.*")) self.service.namespaces[ApplicationService.NS_USERS].append(_regex("@irc_.*"))
self.event.sender = "@someone_else:matrix.org" self.event.sender = "@someone_else:matrix.org"
self.assertFalse( self.assertFalse(
( (
yield defer.ensureDeferred( yield self.service.is_interested_in_event(
self.service.is_interested_in_event(
self.event.event_id, self.event, self.store self.event.event_id, self.event, self.store
) )
) )
) )
)
@defer.inlineCallbacks @defer.inlineCallbacks
def test_regex_room_member_is_checked( def test_regex_room_member_is_checked(
self, self,
) -> Generator["defer.Deferred[object]", object, None]: ) -> Generator["defer.Deferred[Any]", object, None]:
self.service.namespaces[ApplicationService.NS_USERS].append(_regex("@irc_.*")) self.service.namespaces[ApplicationService.NS_USERS].append(_regex("@irc_.*"))
self.event.sender = "@someone_else:matrix.org" self.event.sender = "@someone_else:matrix.org"
self.event.type = "m.room.member" self.event.type = "m.room.member"
self.event.state_key = "@irc_foobar:matrix.org" self.event.state_key = "@irc_foobar:matrix.org"
self.assertTrue( self.assertTrue(
( (
yield defer.ensureDeferred( yield self.service.is_interested_in_event(
self.service.is_interested_in_event(
self.event.event_id, self.event, self.store self.event.event_id, self.event, self.store
) )
) )
) )
)
@defer.inlineCallbacks @defer.inlineCallbacks
def test_regex_room_id_match( def test_regex_room_id_match(
self, self,
) -> Generator["defer.Deferred[object]", object, None]: ) -> Generator["defer.Deferred[Any]", object, None]:
self.service.namespaces[ApplicationService.NS_ROOMS].append( self.service.namespaces[ApplicationService.NS_ROOMS].append(
_regex("!some_prefix.*some_suffix:matrix.org") _regex("!some_prefix.*some_suffix:matrix.org")
) )
self.event.room_id = "!some_prefixs0m3th1nGsome_suffix:matrix.org" self.event.room_id = "!some_prefixs0m3th1nGsome_suffix:matrix.org"
self.assertTrue( self.assertTrue(
( (
yield defer.ensureDeferred( yield self.service.is_interested_in_event(
self.service.is_interested_in_event(
self.event.event_id, self.event, self.store self.event.event_id, self.event, self.store
) )
) )
) )
)
@defer.inlineCallbacks @defer.inlineCallbacks
def test_regex_room_id_no_match( def test_regex_room_id_no_match(
self, self,
) -> Generator["defer.Deferred[object]", object, None]: ) -> Generator["defer.Deferred[Any]", object, None]:
self.service.namespaces[ApplicationService.NS_ROOMS].append( self.service.namespaces[ApplicationService.NS_ROOMS].append(
_regex("!some_prefix.*some_suffix:matrix.org") _regex("!some_prefix.*some_suffix:matrix.org")
) )
self.event.room_id = "!XqBunHwQIXUiqCaoxq:matrix.org" self.event.room_id = "!XqBunHwQIXUiqCaoxq:matrix.org"
self.assertFalse( self.assertFalse(
( (
yield defer.ensureDeferred( yield self.service.is_interested_in_event(
self.service.is_interested_in_event(
self.event.event_id, self.event, self.store self.event.event_id, self.event, self.store
) )
) )
) )
)
@defer.inlineCallbacks @defer.inlineCallbacks
def test_regex_alias_match( def test_regex_alias_match(self) -> Generator["defer.Deferred[Any]", object, None]:
self,
) -> Generator["defer.Deferred[object]", object, None]:
self.service.namespaces[ApplicationService.NS_ALIASES].append( self.service.namespaces[ApplicationService.NS_ALIASES].append(
_regex("#irc_.*:matrix.org") _regex("#irc_.*:matrix.org")
) )
@ -145,13 +133,11 @@ class ApplicationServiceTestCase(unittest.TestCase):
self.store.get_local_users_in_room = simple_async_mock([]) self.store.get_local_users_in_room = simple_async_mock([])
self.assertTrue( self.assertTrue(
( (
yield defer.ensureDeferred( yield self.service.is_interested_in_event(
self.service.is_interested_in_event(
self.event.event_id, self.event, self.store self.event.event_id, self.event, self.store
) )
) )
) )
)
def test_non_exclusive_alias(self) -> None: def test_non_exclusive_alias(self) -> None:
self.service.namespaces[ApplicationService.NS_ALIASES].append( self.service.namespaces[ApplicationService.NS_ALIASES].append(
@ -192,7 +178,7 @@ class ApplicationServiceTestCase(unittest.TestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def test_regex_alias_no_match( def test_regex_alias_no_match(
self, self,
) -> Generator["defer.Deferred[object]", object, None]: ) -> Generator["defer.Deferred[Any]", object, None]:
self.service.namespaces[ApplicationService.NS_ALIASES].append( self.service.namespaces[ApplicationService.NS_ALIASES].append(
_regex("#irc_.*:matrix.org") _regex("#irc_.*:matrix.org")
) )
@ -213,7 +199,7 @@ class ApplicationServiceTestCase(unittest.TestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def test_regex_multiple_matches( def test_regex_multiple_matches(
self, self,
) -> Generator["defer.Deferred[object]", object, None]: ) -> Generator["defer.Deferred[Any]", object, None]:
self.service.namespaces[ApplicationService.NS_ALIASES].append( self.service.namespaces[ApplicationService.NS_ALIASES].append(
_regex("#irc_.*:matrix.org") _regex("#irc_.*:matrix.org")
) )
@ -223,18 +209,14 @@ class ApplicationServiceTestCase(unittest.TestCase):
self.store.get_local_users_in_room = simple_async_mock([]) self.store.get_local_users_in_room = simple_async_mock([])
self.assertTrue( self.assertTrue(
( (
yield defer.ensureDeferred( yield self.service.is_interested_in_event(
self.service.is_interested_in_event(
self.event.event_id, self.event, self.store self.event.event_id, self.event, self.store
) )
) )
) )
)
@defer.inlineCallbacks @defer.inlineCallbacks
def test_interested_in_self( def test_interested_in_self(self) -> Generator["defer.Deferred[Any]", object, None]:
self,
) -> Generator["defer.Deferred[object]", object, None]:
# make sure invites get through # make sure invites get through
self.service.sender = "@appservice:name" self.service.sender = "@appservice:name"
self.service.namespaces[ApplicationService.NS_USERS].append(_regex("@irc_.*")) self.service.namespaces[ApplicationService.NS_USERS].append(_regex("@irc_.*"))
@ -243,18 +225,14 @@ class ApplicationServiceTestCase(unittest.TestCase):
self.event.state_key = self.service.sender self.event.state_key = self.service.sender
self.assertTrue( self.assertTrue(
( (
yield defer.ensureDeferred( yield self.service.is_interested_in_event(
self.service.is_interested_in_event(
self.event.event_id, self.event, self.store self.event.event_id, self.event, self.store
) )
) )
) )
)
@defer.inlineCallbacks @defer.inlineCallbacks
def test_member_list_match( def test_member_list_match(self) -> Generator["defer.Deferred[Any]", object, None]:
self,
) -> Generator["defer.Deferred[object]", object, None]:
self.service.namespaces[ApplicationService.NS_USERS].append(_regex("@irc_.*")) self.service.namespaces[ApplicationService.NS_USERS].append(_regex("@irc_.*"))
# Note that @irc_fo:here is the AS user. # Note that @irc_fo:here is the AS user.
self.store.get_local_users_in_room = simple_async_mock( self.store.get_local_users_in_room = simple_async_mock(
@ -265,10 +243,8 @@ class ApplicationServiceTestCase(unittest.TestCase):
self.event.sender = "@xmpp_foobar:matrix.org" self.event.sender = "@xmpp_foobar:matrix.org"
self.assertTrue( self.assertTrue(
( (
yield defer.ensureDeferred( yield self.service.is_interested_in_event(
self.service.is_interested_in_event(
self.event.event_id, self.event, self.store self.event.event_id, self.event, self.store
) )
) )
) )
)

View File

@ -33,15 +33,14 @@ class TransactionStoreTestCase(HomeserverTestCase):
destination retries, as well as testing tht we can set and get destination retries, as well as testing tht we can set and get
correctly. correctly.
""" """
d = self.store.get_destination_retry_timings("example.com") r = self.get_success(self.store.get_destination_retry_timings("example.com"))
r = self.get_success(d)
self.assertIsNone(r) self.assertIsNone(r)
d = self.store.set_destination_retry_timings("example.com", 1000, 50, 100) self.get_success(
self.get_success(d) self.store.set_destination_retry_timings("example.com", 1000, 50, 100)
)
d = self.store.get_destination_retry_timings("example.com") r = self.get_success(self.store.get_destination_retry_timings("example.com"))
r = self.get_success(d)
self.assertEqual( self.assertEqual(
DestinationRetryTimings( DestinationRetryTimings(