Add missing type hints to `synapse.logging.context` (#11556)

pull/11590/head
Sean Quah 2021-12-14 17:35:28 +00:00 committed by GitHub
parent 2519beaad2
commit 0147b3de20
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
13 changed files with 215 additions and 122 deletions

1
changelog.d/11556.misc Normal file
View File

@ -0,0 +1 @@
Add missing type hints to `synapse.logging.context`.

View File

@ -167,6 +167,9 @@ disallow_untyped_defs = True
[mypy-synapse.http.server] [mypy-synapse.http.server]
disallow_untyped_defs = True disallow_untyped_defs = True
[mypy-synapse.logging.context]
disallow_untyped_defs = True
[mypy-synapse.metrics.*] [mypy-synapse.metrics.*]
disallow_untyped_defs = True disallow_untyped_defs = True

View File

@ -17,11 +17,12 @@
from typing import Any, List, Optional, Type, Union from typing import Any, List, Optional, Type, Union
from twisted.internet import protocol from twisted.internet import protocol
from twisted.internet.defer import Deferred
class RedisProtocol(protocol.Protocol): class RedisProtocol(protocol.Protocol):
def publish(self, channel: str, message: bytes): ... def publish(self, channel: str, message: bytes): ...
async def ping(self) -> None: ... def ping(self) -> "Deferred[None]": ...
async def set( def set(
self, self,
key: str, key: str,
value: Any, value: Any,
@ -29,8 +30,8 @@ class RedisProtocol(protocol.Protocol):
pexpire: Optional[int] = None, pexpire: Optional[int] = None,
only_if_not_exists: bool = False, only_if_not_exists: bool = False,
only_if_exists: bool = False, only_if_exists: bool = False,
) -> None: ... ) -> "Deferred[None]": ...
async def get(self, key: str) -> Any: ... def get(self, key: str) -> "Deferred[Any]": ...
class SubscriberProtocol(RedisProtocol): class SubscriberProtocol(RedisProtocol):
def __init__(self, *args, **kwargs): ... def __init__(self, *args, **kwargs): ...

View File

@ -30,7 +30,6 @@ from typing import (
from prometheus_client import Counter, Gauge, Histogram from prometheus_client import Counter, Gauge, Histogram
from twisted.internet import defer
from twisted.internet.abstract import isIPAddress from twisted.internet.abstract import isIPAddress
from twisted.python import failure from twisted.python import failure
@ -67,7 +66,7 @@ from synapse.replication.http.federation import (
from synapse.storage.databases.main.lock import Lock from synapse.storage.databases.main.lock import Lock
from synapse.types import JsonDict, get_domain_from_id from synapse.types import JsonDict, get_domain_from_id
from synapse.util import glob_to_regex, json_decoder, unwrapFirstError from synapse.util import glob_to_regex, json_decoder, unwrapFirstError
from synapse.util.async_helpers import Linearizer, concurrently_execute from synapse.util.async_helpers import Linearizer, concurrently_execute, gather_results
from synapse.util.caches.response_cache import ResponseCache from synapse.util.caches.response_cache import ResponseCache
from synapse.util.stringutils import parse_server_name from synapse.util.stringutils import parse_server_name
@ -360,13 +359,13 @@ class FederationServer(FederationBase):
# want to block things like to device messages from reaching clients # want to block things like to device messages from reaching clients
# behind the potentially expensive handling of PDUs. # behind the potentially expensive handling of PDUs.
pdu_results, _ = await make_deferred_yieldable( pdu_results, _ = await make_deferred_yieldable(
defer.gatherResults( gather_results(
[ (
run_in_background( run_in_background(
self._handle_pdus_in_txn, origin, transaction, request_time self._handle_pdus_in_txn, origin, transaction, request_time
), ),
run_in_background(self._handle_edus_in_txn, origin, transaction), run_in_background(self._handle_edus_in_txn, origin, transaction),
], ),
consumeErrors=True, consumeErrors=True,
).addErrback(unwrapFirstError) ).addErrback(unwrapFirstError)
) )

View File

@ -360,31 +360,34 @@ class FederationHandler:
logger.debug("calling resolve_state_groups in _maybe_backfill") logger.debug("calling resolve_state_groups in _maybe_backfill")
resolve = preserve_fn(self.state_handler.resolve_state_groups_for_events) resolve = preserve_fn(self.state_handler.resolve_state_groups_for_events)
states = await make_deferred_yieldable( states_list = await make_deferred_yieldable(
defer.gatherResults( defer.gatherResults(
[resolve(room_id, [e]) for e in event_ids], consumeErrors=True [resolve(room_id, [e]) for e in event_ids], consumeErrors=True
) )
) )
# dict[str, dict[tuple, str]], a map from event_id to state map of # A map from event_id to state map of event_ids.
# event_ids. state_ids: Dict[str, StateMap[str]] = dict(
states = dict(zip(event_ids, [s.state for s in states])) zip(event_ids, [s.state for s in states_list])
)
state_map = await self.store.get_events( state_map = await self.store.get_events(
[e_id for ids in states.values() for e_id in ids.values()], [e_id for ids in state_ids.values() for e_id in ids.values()],
get_prev_content=False, get_prev_content=False,
) )
states = {
# A map from event_id to state map of events.
state_events: Dict[str, StateMap[EventBase]] = {
key: { key: {
k: state_map[e_id] k: state_map[e_id]
for k, e_id in state_dict.items() for k, e_id in state_dict.items()
if e_id in state_map if e_id in state_map
} }
for key, state_dict in states.items() for key, state_dict in state_ids.items()
} }
for e_id in event_ids: for e_id in event_ids:
likely_extremeties_domains = get_domains_from_state(states[e_id]) likely_extremeties_domains = get_domains_from_state(state_events[e_id])
success = await try_backfill( success = await try_backfill(
[ [

View File

@ -13,21 +13,27 @@
# limitations under the License. # limitations under the License.
import logging import logging
from typing import TYPE_CHECKING, List, Optional, Tuple from typing import TYPE_CHECKING, List, Optional, Tuple, cast
from twisted.internet import defer
from synapse.api.constants import EduTypes, EventTypes, Membership from synapse.api.constants import EduTypes, EventTypes, Membership
from synapse.api.errors import SynapseError from synapse.api.errors import SynapseError
from synapse.events import EventBase
from synapse.events.validator import EventValidator from synapse.events.validator import EventValidator
from synapse.handlers.presence import format_user_presence_state from synapse.handlers.presence import format_user_presence_state
from synapse.handlers.receipts import ReceiptEventSource from synapse.handlers.receipts import ReceiptEventSource
from synapse.logging.context import make_deferred_yieldable, run_in_background from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.storage.roommember import RoomsForUser from synapse.storage.roommember import RoomsForUser
from synapse.streams.config import PaginationConfig from synapse.streams.config import PaginationConfig
from synapse.types import JsonDict, Requester, RoomStreamToken, StreamToken, UserID from synapse.types import (
JsonDict,
Requester,
RoomStreamToken,
StateMap,
StreamToken,
UserID,
)
from synapse.util import unwrapFirstError from synapse.util import unwrapFirstError
from synapse.util.async_helpers import concurrently_execute from synapse.util.async_helpers import concurrently_execute, gather_results
from synapse.util.caches.response_cache import ResponseCache from synapse.util.caches.response_cache import ResponseCache
from synapse.visibility import filter_events_for_client from synapse.visibility import filter_events_for_client
@ -190,14 +196,13 @@ class InitialSyncHandler:
) )
deferred_room_state = run_in_background( deferred_room_state = run_in_background(
self.state_store.get_state_for_events, [event.event_id] self.state_store.get_state_for_events, [event.event_id]
) ).addCallback(
deferred_room_state.addCallback( lambda states: cast(StateMap[EventBase], states[event.event_id])
lambda states: states[event.event_id]
) )
(messages, token), current_state = await make_deferred_yieldable( (messages, token), current_state = await make_deferred_yieldable(
defer.gatherResults( gather_results(
[ (
run_in_background( run_in_background(
self.store.get_recent_events_for_room, self.store.get_recent_events_for_room,
event.room_id, event.room_id,
@ -205,7 +210,7 @@ class InitialSyncHandler:
end_token=room_end_token, end_token=room_end_token,
), ),
deferred_room_state, deferred_room_state,
] )
) )
).addErrback(unwrapFirstError) ).addErrback(unwrapFirstError)
@ -454,8 +459,8 @@ class InitialSyncHandler:
return receipts return receipts
presence, receipts, (messages, token) = await make_deferred_yieldable( presence, receipts, (messages, token) = await make_deferred_yieldable(
defer.gatherResults( gather_results(
[ (
run_in_background(get_presence), run_in_background(get_presence),
run_in_background(get_receipts), run_in_background(get_receipts),
run_in_background( run_in_background(
@ -464,7 +469,7 @@ class InitialSyncHandler:
limit=limit, limit=limit,
end_token=now_token.room_key, end_token=now_token.room_key,
), ),
], ),
consumeErrors=True, consumeErrors=True,
).addErrback(unwrapFirstError) ).addErrback(unwrapFirstError)
) )

View File

@ -21,7 +21,6 @@ from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Tuple
from canonicaljson import encode_canonical_json from canonicaljson import encode_canonical_json
from twisted.internet import defer
from twisted.internet.interfaces import IDelayedCall from twisted.internet.interfaces import IDelayedCall
from synapse import event_auth from synapse import event_auth
@ -57,7 +56,7 @@ from synapse.storage.databases.main.events_worker import EventRedactBehaviour
from synapse.storage.state import StateFilter from synapse.storage.state import StateFilter
from synapse.types import Requester, RoomAlias, StreamToken, UserID, create_requester from synapse.types import Requester, RoomAlias, StreamToken, UserID, create_requester
from synapse.util import json_decoder, json_encoder, log_failure from synapse.util import json_decoder, json_encoder, log_failure
from synapse.util.async_helpers import Linearizer, unwrapFirstError from synapse.util.async_helpers import Linearizer, gather_results, unwrapFirstError
from synapse.util.caches.expiringcache import ExpiringCache from synapse.util.caches.expiringcache import ExpiringCache
from synapse.util.metrics import measure_func from synapse.util.metrics import measure_func
from synapse.visibility import filter_events_for_client from synapse.visibility import filter_events_for_client
@ -1168,9 +1167,9 @@ class EventCreationHandler:
# We now persist the event (and update the cache in parallel, since we # We now persist the event (and update the cache in parallel, since we
# don't want to block on it). # don't want to block on it).
result = await make_deferred_yieldable( result, _ = await make_deferred_yieldable(
defer.gatherResults( gather_results(
[ (
run_in_background( run_in_background(
self._persist_event, self._persist_event,
requester=requester, requester=requester,
@ -1182,12 +1181,12 @@ class EventCreationHandler:
run_in_background( run_in_background(
self.cache_joined_hosts_for_event, event, context self.cache_joined_hosts_for_event, event, context
).addErrback(log_failure, "cache_joined_hosts_for_event failed"), ).addErrback(log_failure, "cache_joined_hosts_for_event failed"),
], ),
consumeErrors=True, consumeErrors=True,
) )
).addErrback(unwrapFirstError) ).addErrback(unwrapFirstError)
return result[0] return result
async def _persist_event( async def _persist_event(
self, self,

View File

@ -25,6 +25,7 @@ from zope.interface import implementer
from twisted.internet import defer from twisted.internet import defer
from twisted.internet.endpoints import HostnameEndpoint, wrapClientTLS from twisted.internet.endpoints import HostnameEndpoint, wrapClientTLS
from twisted.internet.interfaces import ( from twisted.internet.interfaces import (
IProtocol,
IProtocolFactory, IProtocolFactory,
IReactorCore, IReactorCore,
IStreamClientEndpoint, IStreamClientEndpoint,
@ -309,12 +310,14 @@ class MatrixHostnameEndpoint:
self._srv_resolver = srv_resolver self._srv_resolver = srv_resolver
def connect(self, protocol_factory: IProtocolFactory) -> defer.Deferred: def connect(
self, protocol_factory: IProtocolFactory
) -> "defer.Deferred[IProtocol]":
"""Implements IStreamClientEndpoint interface""" """Implements IStreamClientEndpoint interface"""
return run_in_background(self._do_connect, protocol_factory) return run_in_background(self._do_connect, protocol_factory)
async def _do_connect(self, protocol_factory: IProtocolFactory) -> None: async def _do_connect(self, protocol_factory: IProtocolFactory) -> IProtocol:
first_exception = None first_exception = None
server_list = await self._resolve_server() server_list = await self._resolve_server()

View File

@ -22,20 +22,33 @@ them.
See doc/log_contexts.rst for details on how this works. See doc/log_contexts.rst for details on how this works.
""" """
import inspect
import logging import logging
import threading import threading
import typing import typing
import warnings import warnings
from typing import TYPE_CHECKING, Optional, Tuple, TypeVar, Union from types import TracebackType
from typing import (
TYPE_CHECKING,
Any,
Awaitable,
Callable,
Optional,
Tuple,
Type,
TypeVar,
Union,
overload,
)
import attr import attr
from typing_extensions import Literal from typing_extensions import Literal
from twisted.internet import defer, threads from twisted.internet import defer, threads
from twisted.python.threadpool import ThreadPool
if TYPE_CHECKING: if TYPE_CHECKING:
from synapse.logging.scopecontextmanager import _LogContextScope from synapse.logging.scopecontextmanager import _LogContextScope
from synapse.types import ISynapseReactor
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -66,7 +79,7 @@ except Exception:
# a hook which can be set during testing to assert that we aren't abusing logcontexts. # a hook which can be set during testing to assert that we aren't abusing logcontexts.
def logcontext_error(msg: str): def logcontext_error(msg: str) -> None:
logger.warning(msg) logger.warning(msg)
@ -223,22 +236,19 @@ class _Sentinel:
def __str__(self) -> str: def __str__(self) -> str:
return "sentinel" return "sentinel"
def copy_to(self, record): def start(self, rusage: "Optional[resource.struct_rusage]") -> None:
pass pass
def start(self, rusage: "Optional[resource.struct_rusage]"): def stop(self, rusage: "Optional[resource.struct_rusage]") -> None:
pass pass
def stop(self, rusage: "Optional[resource.struct_rusage]"): def add_database_transaction(self, duration_sec: float) -> None:
pass pass
def add_database_transaction(self, duration_sec): def add_database_scheduled(self, sched_sec: float) -> None:
pass pass
def add_database_scheduled(self, sched_sec): def record_event_fetch(self, event_count: int) -> None:
pass
def record_event_fetch(self, event_count):
pass pass
def __bool__(self) -> Literal[False]: def __bool__(self) -> Literal[False]:
@ -379,7 +389,12 @@ class LoggingContext:
) )
return self return self
def __exit__(self, type, value, traceback) -> None: def __exit__(
self,
type: Optional[Type[BaseException]],
value: Optional[BaseException],
traceback: Optional[TracebackType],
) -> None:
"""Restore the logging context in thread local storage to the state it """Restore the logging context in thread local storage to the state it
was before this context was entered. was before this context was entered.
Returns: Returns:
@ -399,17 +414,6 @@ class LoggingContext:
# recorded against the correct metrics. # recorded against the correct metrics.
self.finished = True self.finished = True
def copy_to(self, record) -> None:
"""Copy logging fields from this context to a log record or
another LoggingContext
"""
# we track the current request
record.request = self.request
# we also track the current scope:
record.scope = self.scope
def start(self, rusage: "Optional[resource.struct_rusage]") -> None: def start(self, rusage: "Optional[resource.struct_rusage]") -> None:
""" """
Record that this logcontext is currently running. Record that this logcontext is currently running.
@ -626,7 +630,12 @@ class PreserveLoggingContext:
def __enter__(self) -> None: def __enter__(self) -> None:
self._old_context = set_current_context(self._new_context) self._old_context = set_current_context(self._new_context)
def __exit__(self, type, value, traceback) -> None: def __exit__(
self,
type: Optional[Type[BaseException]],
value: Optional[BaseException],
traceback: Optional[TracebackType],
) -> None:
context = set_current_context(self._old_context) context = set_current_context(self._old_context)
if context != self._new_context: if context != self._new_context:
@ -711,16 +720,61 @@ def nested_logging_context(suffix: str) -> LoggingContext:
) )
def preserve_fn(f): R = TypeVar("R")
@overload
def preserve_fn( # type: ignore[misc]
f: Callable[..., Awaitable[R]],
) -> Callable[..., "defer.Deferred[R]"]:
# The `type: ignore[misc]` above suppresses
# "Overloaded function signatures 1 and 2 overlap with incompatible return types"
...
@overload
def preserve_fn(f: Callable[..., R]) -> Callable[..., "defer.Deferred[R]"]:
...
def preserve_fn(
f: Union[
Callable[..., R],
Callable[..., Awaitable[R]],
]
) -> Callable[..., "defer.Deferred[R]"]:
"""Function decorator which wraps the function with run_in_background""" """Function decorator which wraps the function with run_in_background"""
def g(*args, **kwargs): def g(*args: Any, **kwargs: Any) -> "defer.Deferred[R]":
return run_in_background(f, *args, **kwargs) return run_in_background(f, *args, **kwargs)
return g return g
def run_in_background(f, *args, **kwargs) -> defer.Deferred: @overload
def run_in_background( # type: ignore[misc]
f: Callable[..., Awaitable[R]], *args: Any, **kwargs: Any
) -> "defer.Deferred[R]":
# The `type: ignore[misc]` above suppresses
# "Overloaded function signatures 1 and 2 overlap with incompatible return types"
...
@overload
def run_in_background(
f: Callable[..., R], *args: Any, **kwargs: Any
) -> "defer.Deferred[R]":
...
def run_in_background(
f: Union[
Callable[..., R],
Callable[..., Awaitable[R]],
],
*args: Any,
**kwargs: Any,
) -> "defer.Deferred[R]":
"""Calls a function, ensuring that the current context is restored after """Calls a function, ensuring that the current context is restored after
return from the function, and that the sentinel context is set once the return from the function, and that the sentinel context is set once the
deferred returned by the function completes. deferred returned by the function completes.
@ -751,6 +805,10 @@ def run_in_background(f, *args, **kwargs) -> defer.Deferred:
# At this point we should have a Deferred, if not then f was a synchronous # At this point we should have a Deferred, if not then f was a synchronous
# function, wrap it in a Deferred for consistency. # function, wrap it in a Deferred for consistency.
if not isinstance(res, defer.Deferred): if not isinstance(res, defer.Deferred):
# `res` is not a `Deferred` and not a `Coroutine`.
# There are no other types of `Awaitable`s we expect to encounter in Synapse.
assert not isinstance(res, Awaitable)
return defer.succeed(res) return defer.succeed(res)
if res.called and not res.paused: if res.called and not res.paused:
@ -778,13 +836,14 @@ def run_in_background(f, *args, **kwargs) -> defer.Deferred:
return res return res
def make_deferred_yieldable(deferred): T = TypeVar("T")
"""Given a deferred (or coroutine), make it follow the Synapse logcontext
rules:
If the deferred has completed (or is not actually a Deferred), essentially
does nothing (just returns another completed deferred with the def make_deferred_yieldable(deferred: "defer.Deferred[T]") -> "defer.Deferred[T]":
result/failure). """Given a deferred, make it follow the Synapse logcontext rules:
If the deferred has completed, essentially does nothing (just returns another
completed deferred with the result/failure).
If the deferred has not yet completed, resets the logcontext before If the deferred has not yet completed, resets the logcontext before
returning a deferred. Then, when the deferred completes, restores the returning a deferred. Then, when the deferred completes, restores the
@ -792,16 +851,6 @@ def make_deferred_yieldable(deferred):
(This is more-or-less the opposite operation to run_in_background.) (This is more-or-less the opposite operation to run_in_background.)
""" """
if inspect.isawaitable(deferred):
# If we're given a coroutine we convert it to a deferred so that we
# run it and find out if it immediately finishes, it it does then we
# don't need to fiddle with log contexts at all and can return
# immediately.
deferred = defer.ensureDeferred(deferred)
if not isinstance(deferred, defer.Deferred):
return deferred
if deferred.called and not deferred.paused: if deferred.called and not deferred.paused:
# it looks like this deferred is ready to run any callbacks we give it # it looks like this deferred is ready to run any callbacks we give it
# immediately. We may as well optimise out the logcontext faffery. # immediately. We may as well optimise out the logcontext faffery.
@ -823,7 +872,9 @@ def _set_context_cb(result: ResultT, context: LoggingContext) -> ResultT:
return result return result
def defer_to_thread(reactor, f, *args, **kwargs): def defer_to_thread(
reactor: "ISynapseReactor", f: Callable[..., R], *args: Any, **kwargs: Any
) -> "defer.Deferred[R]":
""" """
Calls the function `f` using a thread from the reactor's default threadpool and Calls the function `f` using a thread from the reactor's default threadpool and
returns the result as a Deferred. returns the result as a Deferred.
@ -855,7 +906,13 @@ def defer_to_thread(reactor, f, *args, **kwargs):
return defer_to_threadpool(reactor, reactor.getThreadPool(), f, *args, **kwargs) return defer_to_threadpool(reactor, reactor.getThreadPool(), f, *args, **kwargs)
def defer_to_threadpool(reactor, threadpool, f, *args, **kwargs): def defer_to_threadpool(
reactor: "ISynapseReactor",
threadpool: ThreadPool,
f: Callable[..., R],
*args: Any,
**kwargs: Any,
) -> "defer.Deferred[R]":
""" """
A wrapper for twisted.internet.threads.deferToThreadpool, which handles A wrapper for twisted.internet.threads.deferToThreadpool, which handles
logcontexts correctly. logcontexts correctly.
@ -897,7 +954,7 @@ def defer_to_threadpool(reactor, threadpool, f, *args, **kwargs):
assert isinstance(curr_context, LoggingContext) assert isinstance(curr_context, LoggingContext)
parent_context = curr_context parent_context = curr_context
def g(): def g() -> R:
with LoggingContext(str(curr_context), parent_context=parent_context): with LoggingContext(str(curr_context), parent_context=parent_context):
return f(*args, **kwargs) return f(*args, **kwargs)

View File

@ -30,9 +30,11 @@ from typing import (
Iterator, Iterator,
Optional, Optional,
Set, Set,
Tuple,
TypeVar, TypeVar,
Union, Union,
cast, cast,
overload,
) )
import attr import attr
@ -234,6 +236,59 @@ def yieldable_gather_results(
).addErrback(unwrapFirstError) ).addErrback(unwrapFirstError)
T1 = TypeVar("T1")
T2 = TypeVar("T2")
T3 = TypeVar("T3")
@overload
def gather_results(
deferredList: Tuple[()], consumeErrors: bool = ...
) -> "defer.Deferred[Tuple[()]]":
...
@overload
def gather_results(
deferredList: Tuple["defer.Deferred[T1]"],
consumeErrors: bool = ...,
) -> "defer.Deferred[Tuple[T1]]":
...
@overload
def gather_results(
deferredList: Tuple["defer.Deferred[T1]", "defer.Deferred[T2]"],
consumeErrors: bool = ...,
) -> "defer.Deferred[Tuple[T1, T2]]":
...
@overload
def gather_results(
deferredList: Tuple[
"defer.Deferred[T1]", "defer.Deferred[T2]", "defer.Deferred[T3]"
],
consumeErrors: bool = ...,
) -> "defer.Deferred[Tuple[T1, T2, T3]]":
...
def gather_results( # type: ignore[misc]
deferredList: Tuple["defer.Deferred[T1]", ...],
consumeErrors: bool = False,
) -> "defer.Deferred[Tuple[T1, ...]]":
"""Combines a tuple of `Deferred`s into a single `Deferred`.
Wraps `defer.gatherResults` to provide type annotations that support heterogenous
lists of `Deferred`s.
"""
# The `type: ignore[misc]` above suppresses
# "Overloaded function implementation cannot produce return type of signature 1/2/3"
deferred = defer.gatherResults(deferredList, consumeErrors=consumeErrors)
return deferred.addCallback(tuple)
@attr.s(slots=True) @attr.s(slots=True)
class _LinearizerEntry: class _LinearizerEntry:
# The number of things executing. # The number of things executing.
@ -352,7 +407,7 @@ class Linearizer:
logger.debug("Waiting to acquire linearizer lock %r for key %r", self.name, key) logger.debug("Waiting to acquire linearizer lock %r for key %r", self.name, key)
new_defer = make_deferred_yieldable(defer.Deferred()) new_defer: "defer.Deferred[None]" = make_deferred_yieldable(defer.Deferred())
entry.deferreds[new_defer] = 1 entry.deferreds[new_defer] = 1
def cb(_r: None) -> "defer.Deferred[None]": def cb(_r: None) -> "defer.Deferred[None]":

View File

@ -76,6 +76,7 @@ class CachedCall(Generic[TV]):
# Fire off the callable now if this is our first time # Fire off the callable now if this is our first time
if not self._deferred: if not self._deferred:
assert self._callable is not None
self._deferred = run_in_background(self._callable) self._deferred = run_in_background(self._callable)
# we will never need the callable again, so make sure it can be GCed # we will never need the callable again, so make sure it can be GCed

View File

@ -142,6 +142,7 @@ class BackgroundFileConsumer:
def wait(self) -> "Deferred[None]": def wait(self) -> "Deferred[None]":
"""Returns a deferred that resolves when finished writing to file""" """Returns a deferred that resolves when finished writing to file"""
assert self._finished_deferred is not None
return make_deferred_yieldable(self._finished_deferred) return make_deferred_yieldable(self._finished_deferred)
def _resume_paused_producer(self) -> None: def _resume_paused_producer(self) -> None:

View File

@ -152,46 +152,11 @@ class LoggingContextTestCase(unittest.TestCase):
# now it should be restored # now it should be restored
self._check_test_key("one") self._check_test_key("one")
@defer.inlineCallbacks
def test_make_deferred_yieldable_on_non_deferred(self):
"""Check that make_deferred_yieldable does the right thing when its
argument isn't actually a deferred"""
with LoggingContext("one"):
d1 = make_deferred_yieldable("bum")
self._check_test_key("one")
r = yield d1
self.assertEqual(r, "bum")
self._check_test_key("one")
def test_nested_logging_context(self): def test_nested_logging_context(self):
with LoggingContext("foo"): with LoggingContext("foo"):
nested_context = nested_logging_context(suffix="bar") nested_context = nested_logging_context(suffix="bar")
self.assertEqual(nested_context.name, "foo-bar") self.assertEqual(nested_context.name, "foo-bar")
@defer.inlineCallbacks
def test_make_deferred_yieldable_with_await(self):
# an async function which returns an incomplete coroutine, but doesn't
# follow the synapse rules.
async def blocking_function():
d = defer.Deferred()
reactor.callLater(0, d.callback, None)
await d
sentinel_context = current_context()
with LoggingContext("one"):
d1 = make_deferred_yieldable(blocking_function())
# make sure that the context was reset by make_deferred_yieldable
self.assertIs(current_context(), sentinel_context)
yield d1
# now it should be restored
self._check_test_key("one")
# a function which returns a deferred which has been "called", but # a function which returns a deferred which has been "called", but
# which had a function which returned another incomplete deferred on # which had a function which returned another incomplete deferred on