Add missing type hints to `synapse.logging.context` (#11556)
parent
2519beaad2
commit
0147b3de20
|
@ -0,0 +1 @@
|
||||||
|
Add missing type hints to `synapse.logging.context`.
|
3
mypy.ini
3
mypy.ini
|
@ -167,6 +167,9 @@ disallow_untyped_defs = True
|
||||||
[mypy-synapse.http.server]
|
[mypy-synapse.http.server]
|
||||||
disallow_untyped_defs = True
|
disallow_untyped_defs = True
|
||||||
|
|
||||||
|
[mypy-synapse.logging.context]
|
||||||
|
disallow_untyped_defs = True
|
||||||
|
|
||||||
[mypy-synapse.metrics.*]
|
[mypy-synapse.metrics.*]
|
||||||
disallow_untyped_defs = True
|
disallow_untyped_defs = True
|
||||||
|
|
||||||
|
|
|
@ -17,11 +17,12 @@
|
||||||
from typing import Any, List, Optional, Type, Union
|
from typing import Any, List, Optional, Type, Union
|
||||||
|
|
||||||
from twisted.internet import protocol
|
from twisted.internet import protocol
|
||||||
|
from twisted.internet.defer import Deferred
|
||||||
|
|
||||||
class RedisProtocol(protocol.Protocol):
|
class RedisProtocol(protocol.Protocol):
|
||||||
def publish(self, channel: str, message: bytes): ...
|
def publish(self, channel: str, message: bytes): ...
|
||||||
async def ping(self) -> None: ...
|
def ping(self) -> "Deferred[None]": ...
|
||||||
async def set(
|
def set(
|
||||||
self,
|
self,
|
||||||
key: str,
|
key: str,
|
||||||
value: Any,
|
value: Any,
|
||||||
|
@ -29,8 +30,8 @@ class RedisProtocol(protocol.Protocol):
|
||||||
pexpire: Optional[int] = None,
|
pexpire: Optional[int] = None,
|
||||||
only_if_not_exists: bool = False,
|
only_if_not_exists: bool = False,
|
||||||
only_if_exists: bool = False,
|
only_if_exists: bool = False,
|
||||||
) -> None: ...
|
) -> "Deferred[None]": ...
|
||||||
async def get(self, key: str) -> Any: ...
|
def get(self, key: str) -> "Deferred[Any]": ...
|
||||||
|
|
||||||
class SubscriberProtocol(RedisProtocol):
|
class SubscriberProtocol(RedisProtocol):
|
||||||
def __init__(self, *args, **kwargs): ...
|
def __init__(self, *args, **kwargs): ...
|
||||||
|
|
|
@ -30,7 +30,6 @@ from typing import (
|
||||||
|
|
||||||
from prometheus_client import Counter, Gauge, Histogram
|
from prometheus_client import Counter, Gauge, Histogram
|
||||||
|
|
||||||
from twisted.internet import defer
|
|
||||||
from twisted.internet.abstract import isIPAddress
|
from twisted.internet.abstract import isIPAddress
|
||||||
from twisted.python import failure
|
from twisted.python import failure
|
||||||
|
|
||||||
|
@ -67,7 +66,7 @@ from synapse.replication.http.federation import (
|
||||||
from synapse.storage.databases.main.lock import Lock
|
from synapse.storage.databases.main.lock import Lock
|
||||||
from synapse.types import JsonDict, get_domain_from_id
|
from synapse.types import JsonDict, get_domain_from_id
|
||||||
from synapse.util import glob_to_regex, json_decoder, unwrapFirstError
|
from synapse.util import glob_to_regex, json_decoder, unwrapFirstError
|
||||||
from synapse.util.async_helpers import Linearizer, concurrently_execute
|
from synapse.util.async_helpers import Linearizer, concurrently_execute, gather_results
|
||||||
from synapse.util.caches.response_cache import ResponseCache
|
from synapse.util.caches.response_cache import ResponseCache
|
||||||
from synapse.util.stringutils import parse_server_name
|
from synapse.util.stringutils import parse_server_name
|
||||||
|
|
||||||
|
@ -360,13 +359,13 @@ class FederationServer(FederationBase):
|
||||||
# want to block things like to device messages from reaching clients
|
# want to block things like to device messages from reaching clients
|
||||||
# behind the potentially expensive handling of PDUs.
|
# behind the potentially expensive handling of PDUs.
|
||||||
pdu_results, _ = await make_deferred_yieldable(
|
pdu_results, _ = await make_deferred_yieldable(
|
||||||
defer.gatherResults(
|
gather_results(
|
||||||
[
|
(
|
||||||
run_in_background(
|
run_in_background(
|
||||||
self._handle_pdus_in_txn, origin, transaction, request_time
|
self._handle_pdus_in_txn, origin, transaction, request_time
|
||||||
),
|
),
|
||||||
run_in_background(self._handle_edus_in_txn, origin, transaction),
|
run_in_background(self._handle_edus_in_txn, origin, transaction),
|
||||||
],
|
),
|
||||||
consumeErrors=True,
|
consumeErrors=True,
|
||||||
).addErrback(unwrapFirstError)
|
).addErrback(unwrapFirstError)
|
||||||
)
|
)
|
||||||
|
|
|
@ -360,31 +360,34 @@ class FederationHandler:
|
||||||
|
|
||||||
logger.debug("calling resolve_state_groups in _maybe_backfill")
|
logger.debug("calling resolve_state_groups in _maybe_backfill")
|
||||||
resolve = preserve_fn(self.state_handler.resolve_state_groups_for_events)
|
resolve = preserve_fn(self.state_handler.resolve_state_groups_for_events)
|
||||||
states = await make_deferred_yieldable(
|
states_list = await make_deferred_yieldable(
|
||||||
defer.gatherResults(
|
defer.gatherResults(
|
||||||
[resolve(room_id, [e]) for e in event_ids], consumeErrors=True
|
[resolve(room_id, [e]) for e in event_ids], consumeErrors=True
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
# dict[str, dict[tuple, str]], a map from event_id to state map of
|
# A map from event_id to state map of event_ids.
|
||||||
# event_ids.
|
state_ids: Dict[str, StateMap[str]] = dict(
|
||||||
states = dict(zip(event_ids, [s.state for s in states]))
|
zip(event_ids, [s.state for s in states_list])
|
||||||
|
)
|
||||||
|
|
||||||
state_map = await self.store.get_events(
|
state_map = await self.store.get_events(
|
||||||
[e_id for ids in states.values() for e_id in ids.values()],
|
[e_id for ids in state_ids.values() for e_id in ids.values()],
|
||||||
get_prev_content=False,
|
get_prev_content=False,
|
||||||
)
|
)
|
||||||
states = {
|
|
||||||
|
# A map from event_id to state map of events.
|
||||||
|
state_events: Dict[str, StateMap[EventBase]] = {
|
||||||
key: {
|
key: {
|
||||||
k: state_map[e_id]
|
k: state_map[e_id]
|
||||||
for k, e_id in state_dict.items()
|
for k, e_id in state_dict.items()
|
||||||
if e_id in state_map
|
if e_id in state_map
|
||||||
}
|
}
|
||||||
for key, state_dict in states.items()
|
for key, state_dict in state_ids.items()
|
||||||
}
|
}
|
||||||
|
|
||||||
for e_id in event_ids:
|
for e_id in event_ids:
|
||||||
likely_extremeties_domains = get_domains_from_state(states[e_id])
|
likely_extremeties_domains = get_domains_from_state(state_events[e_id])
|
||||||
|
|
||||||
success = await try_backfill(
|
success = await try_backfill(
|
||||||
[
|
[
|
||||||
|
|
|
@ -13,21 +13,27 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from typing import TYPE_CHECKING, List, Optional, Tuple
|
from typing import TYPE_CHECKING, List, Optional, Tuple, cast
|
||||||
|
|
||||||
from twisted.internet import defer
|
|
||||||
|
|
||||||
from synapse.api.constants import EduTypes, EventTypes, Membership
|
from synapse.api.constants import EduTypes, EventTypes, Membership
|
||||||
from synapse.api.errors import SynapseError
|
from synapse.api.errors import SynapseError
|
||||||
|
from synapse.events import EventBase
|
||||||
from synapse.events.validator import EventValidator
|
from synapse.events.validator import EventValidator
|
||||||
from synapse.handlers.presence import format_user_presence_state
|
from synapse.handlers.presence import format_user_presence_state
|
||||||
from synapse.handlers.receipts import ReceiptEventSource
|
from synapse.handlers.receipts import ReceiptEventSource
|
||||||
from synapse.logging.context import make_deferred_yieldable, run_in_background
|
from synapse.logging.context import make_deferred_yieldable, run_in_background
|
||||||
from synapse.storage.roommember import RoomsForUser
|
from synapse.storage.roommember import RoomsForUser
|
||||||
from synapse.streams.config import PaginationConfig
|
from synapse.streams.config import PaginationConfig
|
||||||
from synapse.types import JsonDict, Requester, RoomStreamToken, StreamToken, UserID
|
from synapse.types import (
|
||||||
|
JsonDict,
|
||||||
|
Requester,
|
||||||
|
RoomStreamToken,
|
||||||
|
StateMap,
|
||||||
|
StreamToken,
|
||||||
|
UserID,
|
||||||
|
)
|
||||||
from synapse.util import unwrapFirstError
|
from synapse.util import unwrapFirstError
|
||||||
from synapse.util.async_helpers import concurrently_execute
|
from synapse.util.async_helpers import concurrently_execute, gather_results
|
||||||
from synapse.util.caches.response_cache import ResponseCache
|
from synapse.util.caches.response_cache import ResponseCache
|
||||||
from synapse.visibility import filter_events_for_client
|
from synapse.visibility import filter_events_for_client
|
||||||
|
|
||||||
|
@ -190,14 +196,13 @@ class InitialSyncHandler:
|
||||||
)
|
)
|
||||||
deferred_room_state = run_in_background(
|
deferred_room_state = run_in_background(
|
||||||
self.state_store.get_state_for_events, [event.event_id]
|
self.state_store.get_state_for_events, [event.event_id]
|
||||||
)
|
).addCallback(
|
||||||
deferred_room_state.addCallback(
|
lambda states: cast(StateMap[EventBase], states[event.event_id])
|
||||||
lambda states: states[event.event_id]
|
|
||||||
)
|
)
|
||||||
|
|
||||||
(messages, token), current_state = await make_deferred_yieldable(
|
(messages, token), current_state = await make_deferred_yieldable(
|
||||||
defer.gatherResults(
|
gather_results(
|
||||||
[
|
(
|
||||||
run_in_background(
|
run_in_background(
|
||||||
self.store.get_recent_events_for_room,
|
self.store.get_recent_events_for_room,
|
||||||
event.room_id,
|
event.room_id,
|
||||||
|
@ -205,7 +210,7 @@ class InitialSyncHandler:
|
||||||
end_token=room_end_token,
|
end_token=room_end_token,
|
||||||
),
|
),
|
||||||
deferred_room_state,
|
deferred_room_state,
|
||||||
]
|
)
|
||||||
)
|
)
|
||||||
).addErrback(unwrapFirstError)
|
).addErrback(unwrapFirstError)
|
||||||
|
|
||||||
|
@ -454,8 +459,8 @@ class InitialSyncHandler:
|
||||||
return receipts
|
return receipts
|
||||||
|
|
||||||
presence, receipts, (messages, token) = await make_deferred_yieldable(
|
presence, receipts, (messages, token) = await make_deferred_yieldable(
|
||||||
defer.gatherResults(
|
gather_results(
|
||||||
[
|
(
|
||||||
run_in_background(get_presence),
|
run_in_background(get_presence),
|
||||||
run_in_background(get_receipts),
|
run_in_background(get_receipts),
|
||||||
run_in_background(
|
run_in_background(
|
||||||
|
@ -464,7 +469,7 @@ class InitialSyncHandler:
|
||||||
limit=limit,
|
limit=limit,
|
||||||
end_token=now_token.room_key,
|
end_token=now_token.room_key,
|
||||||
),
|
),
|
||||||
],
|
),
|
||||||
consumeErrors=True,
|
consumeErrors=True,
|
||||||
).addErrback(unwrapFirstError)
|
).addErrback(unwrapFirstError)
|
||||||
)
|
)
|
||||||
|
|
|
@ -21,7 +21,6 @@ from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Tuple
|
||||||
|
|
||||||
from canonicaljson import encode_canonical_json
|
from canonicaljson import encode_canonical_json
|
||||||
|
|
||||||
from twisted.internet import defer
|
|
||||||
from twisted.internet.interfaces import IDelayedCall
|
from twisted.internet.interfaces import IDelayedCall
|
||||||
|
|
||||||
from synapse import event_auth
|
from synapse import event_auth
|
||||||
|
@ -57,7 +56,7 @@ from synapse.storage.databases.main.events_worker import EventRedactBehaviour
|
||||||
from synapse.storage.state import StateFilter
|
from synapse.storage.state import StateFilter
|
||||||
from synapse.types import Requester, RoomAlias, StreamToken, UserID, create_requester
|
from synapse.types import Requester, RoomAlias, StreamToken, UserID, create_requester
|
||||||
from synapse.util import json_decoder, json_encoder, log_failure
|
from synapse.util import json_decoder, json_encoder, log_failure
|
||||||
from synapse.util.async_helpers import Linearizer, unwrapFirstError
|
from synapse.util.async_helpers import Linearizer, gather_results, unwrapFirstError
|
||||||
from synapse.util.caches.expiringcache import ExpiringCache
|
from synapse.util.caches.expiringcache import ExpiringCache
|
||||||
from synapse.util.metrics import measure_func
|
from synapse.util.metrics import measure_func
|
||||||
from synapse.visibility import filter_events_for_client
|
from synapse.visibility import filter_events_for_client
|
||||||
|
@ -1168,9 +1167,9 @@ class EventCreationHandler:
|
||||||
|
|
||||||
# We now persist the event (and update the cache in parallel, since we
|
# We now persist the event (and update the cache in parallel, since we
|
||||||
# don't want to block on it).
|
# don't want to block on it).
|
||||||
result = await make_deferred_yieldable(
|
result, _ = await make_deferred_yieldable(
|
||||||
defer.gatherResults(
|
gather_results(
|
||||||
[
|
(
|
||||||
run_in_background(
|
run_in_background(
|
||||||
self._persist_event,
|
self._persist_event,
|
||||||
requester=requester,
|
requester=requester,
|
||||||
|
@ -1182,12 +1181,12 @@ class EventCreationHandler:
|
||||||
run_in_background(
|
run_in_background(
|
||||||
self.cache_joined_hosts_for_event, event, context
|
self.cache_joined_hosts_for_event, event, context
|
||||||
).addErrback(log_failure, "cache_joined_hosts_for_event failed"),
|
).addErrback(log_failure, "cache_joined_hosts_for_event failed"),
|
||||||
],
|
),
|
||||||
consumeErrors=True,
|
consumeErrors=True,
|
||||||
)
|
)
|
||||||
).addErrback(unwrapFirstError)
|
).addErrback(unwrapFirstError)
|
||||||
|
|
||||||
return result[0]
|
return result
|
||||||
|
|
||||||
async def _persist_event(
|
async def _persist_event(
|
||||||
self,
|
self,
|
||||||
|
|
|
@ -25,6 +25,7 @@ from zope.interface import implementer
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
from twisted.internet.endpoints import HostnameEndpoint, wrapClientTLS
|
from twisted.internet.endpoints import HostnameEndpoint, wrapClientTLS
|
||||||
from twisted.internet.interfaces import (
|
from twisted.internet.interfaces import (
|
||||||
|
IProtocol,
|
||||||
IProtocolFactory,
|
IProtocolFactory,
|
||||||
IReactorCore,
|
IReactorCore,
|
||||||
IStreamClientEndpoint,
|
IStreamClientEndpoint,
|
||||||
|
@ -309,12 +310,14 @@ class MatrixHostnameEndpoint:
|
||||||
|
|
||||||
self._srv_resolver = srv_resolver
|
self._srv_resolver = srv_resolver
|
||||||
|
|
||||||
def connect(self, protocol_factory: IProtocolFactory) -> defer.Deferred:
|
def connect(
|
||||||
|
self, protocol_factory: IProtocolFactory
|
||||||
|
) -> "defer.Deferred[IProtocol]":
|
||||||
"""Implements IStreamClientEndpoint interface"""
|
"""Implements IStreamClientEndpoint interface"""
|
||||||
|
|
||||||
return run_in_background(self._do_connect, protocol_factory)
|
return run_in_background(self._do_connect, protocol_factory)
|
||||||
|
|
||||||
async def _do_connect(self, protocol_factory: IProtocolFactory) -> None:
|
async def _do_connect(self, protocol_factory: IProtocolFactory) -> IProtocol:
|
||||||
first_exception = None
|
first_exception = None
|
||||||
|
|
||||||
server_list = await self._resolve_server()
|
server_list = await self._resolve_server()
|
||||||
|
|
|
@ -22,20 +22,33 @@ them.
|
||||||
|
|
||||||
See doc/log_contexts.rst for details on how this works.
|
See doc/log_contexts.rst for details on how this works.
|
||||||
"""
|
"""
|
||||||
import inspect
|
|
||||||
import logging
|
import logging
|
||||||
import threading
|
import threading
|
||||||
import typing
|
import typing
|
||||||
import warnings
|
import warnings
|
||||||
from typing import TYPE_CHECKING, Optional, Tuple, TypeVar, Union
|
from types import TracebackType
|
||||||
|
from typing import (
|
||||||
|
TYPE_CHECKING,
|
||||||
|
Any,
|
||||||
|
Awaitable,
|
||||||
|
Callable,
|
||||||
|
Optional,
|
||||||
|
Tuple,
|
||||||
|
Type,
|
||||||
|
TypeVar,
|
||||||
|
Union,
|
||||||
|
overload,
|
||||||
|
)
|
||||||
|
|
||||||
import attr
|
import attr
|
||||||
from typing_extensions import Literal
|
from typing_extensions import Literal
|
||||||
|
|
||||||
from twisted.internet import defer, threads
|
from twisted.internet import defer, threads
|
||||||
|
from twisted.python.threadpool import ThreadPool
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from synapse.logging.scopecontextmanager import _LogContextScope
|
from synapse.logging.scopecontextmanager import _LogContextScope
|
||||||
|
from synapse.types import ISynapseReactor
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
@ -66,7 +79,7 @@ except Exception:
|
||||||
|
|
||||||
|
|
||||||
# a hook which can be set during testing to assert that we aren't abusing logcontexts.
|
# a hook which can be set during testing to assert that we aren't abusing logcontexts.
|
||||||
def logcontext_error(msg: str):
|
def logcontext_error(msg: str) -> None:
|
||||||
logger.warning(msg)
|
logger.warning(msg)
|
||||||
|
|
||||||
|
|
||||||
|
@ -223,22 +236,19 @@ class _Sentinel:
|
||||||
def __str__(self) -> str:
|
def __str__(self) -> str:
|
||||||
return "sentinel"
|
return "sentinel"
|
||||||
|
|
||||||
def copy_to(self, record):
|
def start(self, rusage: "Optional[resource.struct_rusage]") -> None:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def start(self, rusage: "Optional[resource.struct_rusage]"):
|
def stop(self, rusage: "Optional[resource.struct_rusage]") -> None:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def stop(self, rusage: "Optional[resource.struct_rusage]"):
|
def add_database_transaction(self, duration_sec: float) -> None:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def add_database_transaction(self, duration_sec):
|
def add_database_scheduled(self, sched_sec: float) -> None:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def add_database_scheduled(self, sched_sec):
|
def record_event_fetch(self, event_count: int) -> None:
|
||||||
pass
|
|
||||||
|
|
||||||
def record_event_fetch(self, event_count):
|
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def __bool__(self) -> Literal[False]:
|
def __bool__(self) -> Literal[False]:
|
||||||
|
@ -379,7 +389,12 @@ class LoggingContext:
|
||||||
)
|
)
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def __exit__(self, type, value, traceback) -> None:
|
def __exit__(
|
||||||
|
self,
|
||||||
|
type: Optional[Type[BaseException]],
|
||||||
|
value: Optional[BaseException],
|
||||||
|
traceback: Optional[TracebackType],
|
||||||
|
) -> None:
|
||||||
"""Restore the logging context in thread local storage to the state it
|
"""Restore the logging context in thread local storage to the state it
|
||||||
was before this context was entered.
|
was before this context was entered.
|
||||||
Returns:
|
Returns:
|
||||||
|
@ -399,17 +414,6 @@ class LoggingContext:
|
||||||
# recorded against the correct metrics.
|
# recorded against the correct metrics.
|
||||||
self.finished = True
|
self.finished = True
|
||||||
|
|
||||||
def copy_to(self, record) -> None:
|
|
||||||
"""Copy logging fields from this context to a log record or
|
|
||||||
another LoggingContext
|
|
||||||
"""
|
|
||||||
|
|
||||||
# we track the current request
|
|
||||||
record.request = self.request
|
|
||||||
|
|
||||||
# we also track the current scope:
|
|
||||||
record.scope = self.scope
|
|
||||||
|
|
||||||
def start(self, rusage: "Optional[resource.struct_rusage]") -> None:
|
def start(self, rusage: "Optional[resource.struct_rusage]") -> None:
|
||||||
"""
|
"""
|
||||||
Record that this logcontext is currently running.
|
Record that this logcontext is currently running.
|
||||||
|
@ -626,7 +630,12 @@ class PreserveLoggingContext:
|
||||||
def __enter__(self) -> None:
|
def __enter__(self) -> None:
|
||||||
self._old_context = set_current_context(self._new_context)
|
self._old_context = set_current_context(self._new_context)
|
||||||
|
|
||||||
def __exit__(self, type, value, traceback) -> None:
|
def __exit__(
|
||||||
|
self,
|
||||||
|
type: Optional[Type[BaseException]],
|
||||||
|
value: Optional[BaseException],
|
||||||
|
traceback: Optional[TracebackType],
|
||||||
|
) -> None:
|
||||||
context = set_current_context(self._old_context)
|
context = set_current_context(self._old_context)
|
||||||
|
|
||||||
if context != self._new_context:
|
if context != self._new_context:
|
||||||
|
@ -711,16 +720,61 @@ def nested_logging_context(suffix: str) -> LoggingContext:
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def preserve_fn(f):
|
R = TypeVar("R")
|
||||||
|
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def preserve_fn( # type: ignore[misc]
|
||||||
|
f: Callable[..., Awaitable[R]],
|
||||||
|
) -> Callable[..., "defer.Deferred[R]"]:
|
||||||
|
# The `type: ignore[misc]` above suppresses
|
||||||
|
# "Overloaded function signatures 1 and 2 overlap with incompatible return types"
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def preserve_fn(f: Callable[..., R]) -> Callable[..., "defer.Deferred[R]"]:
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
|
def preserve_fn(
|
||||||
|
f: Union[
|
||||||
|
Callable[..., R],
|
||||||
|
Callable[..., Awaitable[R]],
|
||||||
|
]
|
||||||
|
) -> Callable[..., "defer.Deferred[R]"]:
|
||||||
"""Function decorator which wraps the function with run_in_background"""
|
"""Function decorator which wraps the function with run_in_background"""
|
||||||
|
|
||||||
def g(*args, **kwargs):
|
def g(*args: Any, **kwargs: Any) -> "defer.Deferred[R]":
|
||||||
return run_in_background(f, *args, **kwargs)
|
return run_in_background(f, *args, **kwargs)
|
||||||
|
|
||||||
return g
|
return g
|
||||||
|
|
||||||
|
|
||||||
def run_in_background(f, *args, **kwargs) -> defer.Deferred:
|
@overload
|
||||||
|
def run_in_background( # type: ignore[misc]
|
||||||
|
f: Callable[..., Awaitable[R]], *args: Any, **kwargs: Any
|
||||||
|
) -> "defer.Deferred[R]":
|
||||||
|
# The `type: ignore[misc]` above suppresses
|
||||||
|
# "Overloaded function signatures 1 and 2 overlap with incompatible return types"
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def run_in_background(
|
||||||
|
f: Callable[..., R], *args: Any, **kwargs: Any
|
||||||
|
) -> "defer.Deferred[R]":
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
|
def run_in_background(
|
||||||
|
f: Union[
|
||||||
|
Callable[..., R],
|
||||||
|
Callable[..., Awaitable[R]],
|
||||||
|
],
|
||||||
|
*args: Any,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> "defer.Deferred[R]":
|
||||||
"""Calls a function, ensuring that the current context is restored after
|
"""Calls a function, ensuring that the current context is restored after
|
||||||
return from the function, and that the sentinel context is set once the
|
return from the function, and that the sentinel context is set once the
|
||||||
deferred returned by the function completes.
|
deferred returned by the function completes.
|
||||||
|
@ -751,6 +805,10 @@ def run_in_background(f, *args, **kwargs) -> defer.Deferred:
|
||||||
# At this point we should have a Deferred, if not then f was a synchronous
|
# At this point we should have a Deferred, if not then f was a synchronous
|
||||||
# function, wrap it in a Deferred for consistency.
|
# function, wrap it in a Deferred for consistency.
|
||||||
if not isinstance(res, defer.Deferred):
|
if not isinstance(res, defer.Deferred):
|
||||||
|
# `res` is not a `Deferred` and not a `Coroutine`.
|
||||||
|
# There are no other types of `Awaitable`s we expect to encounter in Synapse.
|
||||||
|
assert not isinstance(res, Awaitable)
|
||||||
|
|
||||||
return defer.succeed(res)
|
return defer.succeed(res)
|
||||||
|
|
||||||
if res.called and not res.paused:
|
if res.called and not res.paused:
|
||||||
|
@ -778,13 +836,14 @@ def run_in_background(f, *args, **kwargs) -> defer.Deferred:
|
||||||
return res
|
return res
|
||||||
|
|
||||||
|
|
||||||
def make_deferred_yieldable(deferred):
|
T = TypeVar("T")
|
||||||
"""Given a deferred (or coroutine), make it follow the Synapse logcontext
|
|
||||||
rules:
|
|
||||||
|
|
||||||
If the deferred has completed (or is not actually a Deferred), essentially
|
|
||||||
does nothing (just returns another completed deferred with the
|
def make_deferred_yieldable(deferred: "defer.Deferred[T]") -> "defer.Deferred[T]":
|
||||||
result/failure).
|
"""Given a deferred, make it follow the Synapse logcontext rules:
|
||||||
|
|
||||||
|
If the deferred has completed, essentially does nothing (just returns another
|
||||||
|
completed deferred with the result/failure).
|
||||||
|
|
||||||
If the deferred has not yet completed, resets the logcontext before
|
If the deferred has not yet completed, resets the logcontext before
|
||||||
returning a deferred. Then, when the deferred completes, restores the
|
returning a deferred. Then, when the deferred completes, restores the
|
||||||
|
@ -792,16 +851,6 @@ def make_deferred_yieldable(deferred):
|
||||||
|
|
||||||
(This is more-or-less the opposite operation to run_in_background.)
|
(This is more-or-less the opposite operation to run_in_background.)
|
||||||
"""
|
"""
|
||||||
if inspect.isawaitable(deferred):
|
|
||||||
# If we're given a coroutine we convert it to a deferred so that we
|
|
||||||
# run it and find out if it immediately finishes, it it does then we
|
|
||||||
# don't need to fiddle with log contexts at all and can return
|
|
||||||
# immediately.
|
|
||||||
deferred = defer.ensureDeferred(deferred)
|
|
||||||
|
|
||||||
if not isinstance(deferred, defer.Deferred):
|
|
||||||
return deferred
|
|
||||||
|
|
||||||
if deferred.called and not deferred.paused:
|
if deferred.called and not deferred.paused:
|
||||||
# it looks like this deferred is ready to run any callbacks we give it
|
# it looks like this deferred is ready to run any callbacks we give it
|
||||||
# immediately. We may as well optimise out the logcontext faffery.
|
# immediately. We may as well optimise out the logcontext faffery.
|
||||||
|
@ -823,7 +872,9 @@ def _set_context_cb(result: ResultT, context: LoggingContext) -> ResultT:
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
def defer_to_thread(reactor, f, *args, **kwargs):
|
def defer_to_thread(
|
||||||
|
reactor: "ISynapseReactor", f: Callable[..., R], *args: Any, **kwargs: Any
|
||||||
|
) -> "defer.Deferred[R]":
|
||||||
"""
|
"""
|
||||||
Calls the function `f` using a thread from the reactor's default threadpool and
|
Calls the function `f` using a thread from the reactor's default threadpool and
|
||||||
returns the result as a Deferred.
|
returns the result as a Deferred.
|
||||||
|
@ -855,7 +906,13 @@ def defer_to_thread(reactor, f, *args, **kwargs):
|
||||||
return defer_to_threadpool(reactor, reactor.getThreadPool(), f, *args, **kwargs)
|
return defer_to_threadpool(reactor, reactor.getThreadPool(), f, *args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
def defer_to_threadpool(reactor, threadpool, f, *args, **kwargs):
|
def defer_to_threadpool(
|
||||||
|
reactor: "ISynapseReactor",
|
||||||
|
threadpool: ThreadPool,
|
||||||
|
f: Callable[..., R],
|
||||||
|
*args: Any,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> "defer.Deferred[R]":
|
||||||
"""
|
"""
|
||||||
A wrapper for twisted.internet.threads.deferToThreadpool, which handles
|
A wrapper for twisted.internet.threads.deferToThreadpool, which handles
|
||||||
logcontexts correctly.
|
logcontexts correctly.
|
||||||
|
@ -897,7 +954,7 @@ def defer_to_threadpool(reactor, threadpool, f, *args, **kwargs):
|
||||||
assert isinstance(curr_context, LoggingContext)
|
assert isinstance(curr_context, LoggingContext)
|
||||||
parent_context = curr_context
|
parent_context = curr_context
|
||||||
|
|
||||||
def g():
|
def g() -> R:
|
||||||
with LoggingContext(str(curr_context), parent_context=parent_context):
|
with LoggingContext(str(curr_context), parent_context=parent_context):
|
||||||
return f(*args, **kwargs)
|
return f(*args, **kwargs)
|
||||||
|
|
||||||
|
|
|
@ -30,9 +30,11 @@ from typing import (
|
||||||
Iterator,
|
Iterator,
|
||||||
Optional,
|
Optional,
|
||||||
Set,
|
Set,
|
||||||
|
Tuple,
|
||||||
TypeVar,
|
TypeVar,
|
||||||
Union,
|
Union,
|
||||||
cast,
|
cast,
|
||||||
|
overload,
|
||||||
)
|
)
|
||||||
|
|
||||||
import attr
|
import attr
|
||||||
|
@ -234,6 +236,59 @@ def yieldable_gather_results(
|
||||||
).addErrback(unwrapFirstError)
|
).addErrback(unwrapFirstError)
|
||||||
|
|
||||||
|
|
||||||
|
T1 = TypeVar("T1")
|
||||||
|
T2 = TypeVar("T2")
|
||||||
|
T3 = TypeVar("T3")
|
||||||
|
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def gather_results(
|
||||||
|
deferredList: Tuple[()], consumeErrors: bool = ...
|
||||||
|
) -> "defer.Deferred[Tuple[()]]":
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def gather_results(
|
||||||
|
deferredList: Tuple["defer.Deferred[T1]"],
|
||||||
|
consumeErrors: bool = ...,
|
||||||
|
) -> "defer.Deferred[Tuple[T1]]":
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def gather_results(
|
||||||
|
deferredList: Tuple["defer.Deferred[T1]", "defer.Deferred[T2]"],
|
||||||
|
consumeErrors: bool = ...,
|
||||||
|
) -> "defer.Deferred[Tuple[T1, T2]]":
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def gather_results(
|
||||||
|
deferredList: Tuple[
|
||||||
|
"defer.Deferred[T1]", "defer.Deferred[T2]", "defer.Deferred[T3]"
|
||||||
|
],
|
||||||
|
consumeErrors: bool = ...,
|
||||||
|
) -> "defer.Deferred[Tuple[T1, T2, T3]]":
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
|
def gather_results( # type: ignore[misc]
|
||||||
|
deferredList: Tuple["defer.Deferred[T1]", ...],
|
||||||
|
consumeErrors: bool = False,
|
||||||
|
) -> "defer.Deferred[Tuple[T1, ...]]":
|
||||||
|
"""Combines a tuple of `Deferred`s into a single `Deferred`.
|
||||||
|
|
||||||
|
Wraps `defer.gatherResults` to provide type annotations that support heterogenous
|
||||||
|
lists of `Deferred`s.
|
||||||
|
"""
|
||||||
|
# The `type: ignore[misc]` above suppresses
|
||||||
|
# "Overloaded function implementation cannot produce return type of signature 1/2/3"
|
||||||
|
deferred = defer.gatherResults(deferredList, consumeErrors=consumeErrors)
|
||||||
|
return deferred.addCallback(tuple)
|
||||||
|
|
||||||
|
|
||||||
@attr.s(slots=True)
|
@attr.s(slots=True)
|
||||||
class _LinearizerEntry:
|
class _LinearizerEntry:
|
||||||
# The number of things executing.
|
# The number of things executing.
|
||||||
|
@ -352,7 +407,7 @@ class Linearizer:
|
||||||
|
|
||||||
logger.debug("Waiting to acquire linearizer lock %r for key %r", self.name, key)
|
logger.debug("Waiting to acquire linearizer lock %r for key %r", self.name, key)
|
||||||
|
|
||||||
new_defer = make_deferred_yieldable(defer.Deferred())
|
new_defer: "defer.Deferred[None]" = make_deferred_yieldable(defer.Deferred())
|
||||||
entry.deferreds[new_defer] = 1
|
entry.deferreds[new_defer] = 1
|
||||||
|
|
||||||
def cb(_r: None) -> "defer.Deferred[None]":
|
def cb(_r: None) -> "defer.Deferred[None]":
|
||||||
|
|
|
@ -76,6 +76,7 @@ class CachedCall(Generic[TV]):
|
||||||
|
|
||||||
# Fire off the callable now if this is our first time
|
# Fire off the callable now if this is our first time
|
||||||
if not self._deferred:
|
if not self._deferred:
|
||||||
|
assert self._callable is not None
|
||||||
self._deferred = run_in_background(self._callable)
|
self._deferred = run_in_background(self._callable)
|
||||||
|
|
||||||
# we will never need the callable again, so make sure it can be GCed
|
# we will never need the callable again, so make sure it can be GCed
|
||||||
|
|
|
@ -142,6 +142,7 @@ class BackgroundFileConsumer:
|
||||||
|
|
||||||
def wait(self) -> "Deferred[None]":
|
def wait(self) -> "Deferred[None]":
|
||||||
"""Returns a deferred that resolves when finished writing to file"""
|
"""Returns a deferred that resolves when finished writing to file"""
|
||||||
|
assert self._finished_deferred is not None
|
||||||
return make_deferred_yieldable(self._finished_deferred)
|
return make_deferred_yieldable(self._finished_deferred)
|
||||||
|
|
||||||
def _resume_paused_producer(self) -> None:
|
def _resume_paused_producer(self) -> None:
|
||||||
|
|
|
@ -152,46 +152,11 @@ class LoggingContextTestCase(unittest.TestCase):
|
||||||
# now it should be restored
|
# now it should be restored
|
||||||
self._check_test_key("one")
|
self._check_test_key("one")
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
|
||||||
def test_make_deferred_yieldable_on_non_deferred(self):
|
|
||||||
"""Check that make_deferred_yieldable does the right thing when its
|
|
||||||
argument isn't actually a deferred"""
|
|
||||||
|
|
||||||
with LoggingContext("one"):
|
|
||||||
d1 = make_deferred_yieldable("bum")
|
|
||||||
self._check_test_key("one")
|
|
||||||
|
|
||||||
r = yield d1
|
|
||||||
self.assertEqual(r, "bum")
|
|
||||||
self._check_test_key("one")
|
|
||||||
|
|
||||||
def test_nested_logging_context(self):
|
def test_nested_logging_context(self):
|
||||||
with LoggingContext("foo"):
|
with LoggingContext("foo"):
|
||||||
nested_context = nested_logging_context(suffix="bar")
|
nested_context = nested_logging_context(suffix="bar")
|
||||||
self.assertEqual(nested_context.name, "foo-bar")
|
self.assertEqual(nested_context.name, "foo-bar")
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
|
||||||
def test_make_deferred_yieldable_with_await(self):
|
|
||||||
# an async function which returns an incomplete coroutine, but doesn't
|
|
||||||
# follow the synapse rules.
|
|
||||||
|
|
||||||
async def blocking_function():
|
|
||||||
d = defer.Deferred()
|
|
||||||
reactor.callLater(0, d.callback, None)
|
|
||||||
await d
|
|
||||||
|
|
||||||
sentinel_context = current_context()
|
|
||||||
|
|
||||||
with LoggingContext("one"):
|
|
||||||
d1 = make_deferred_yieldable(blocking_function())
|
|
||||||
# make sure that the context was reset by make_deferred_yieldable
|
|
||||||
self.assertIs(current_context(), sentinel_context)
|
|
||||||
|
|
||||||
yield d1
|
|
||||||
|
|
||||||
# now it should be restored
|
|
||||||
self._check_test_key("one")
|
|
||||||
|
|
||||||
|
|
||||||
# a function which returns a deferred which has been "called", but
|
# a function which returns a deferred which has been "called", but
|
||||||
# which had a function which returned another incomplete deferred on
|
# which had a function which returned another incomplete deferred on
|
||||||
|
|
Loading…
Reference in New Issue