Appease mypy

pull/6127/head
Erik Johnston 2019-10-10 12:15:17 +01:00
parent 791a8c559b
commit 941edad583
1 changed files with 18 additions and 13 deletions

View File

@ -18,11 +18,17 @@ from __future__ import print_function
import functools
import sys
from typing import List, Callable, Any
from twisted.internet import defer
from twisted.internet.defer import Deferred
from twisted.python.failure import Failure
# Tracks if we've already patched inlineCallbacks
_already_patched = False
def do_patch():
"""
Patch defer.inlineCallbacks so that it checks the state of the logcontext on exit
@ -30,16 +36,18 @@ def do_patch():
from synapse.logging.context import LoggingContext
global _already_patched
orig_inline_callbacks = defer.inlineCallbacks
if hasattr(orig_inline_callbacks, "patched_by_synapse"):
if _already_patched:
return
def new_inline_callbacks(f):
@functools.wraps(f)
def wrapped(*args, **kwargs):
start_context = LoggingContext.current_context()
changes = []
orig = orig_inline_callbacks(_check_yield_points(f, changes, start_context))
changes: List[str] = []
orig = orig_inline_callbacks(_check_yield_points(f, changes))
try:
res = orig(*args, **kwargs)
@ -101,10 +109,10 @@ def do_patch():
return wrapped
defer.inlineCallbacks = new_inline_callbacks
new_inline_callbacks.patched_by_synapse = True
_already_patched = True
def _check_yield_points(f, changes, start_context):
def _check_yield_points(f: Callable, changes: List[str]):
"""Wraps a generator that is about to be passed to defer.inlineCallbacks
checking that after every yield the log contexts are correct.
@ -114,9 +122,8 @@ def _check_yield_points(f, changes, start_context):
Args:
f: generator function to wrap
changes (list[str]): A list of strings detailing how the contexts
changes: A list of strings detailing how the contexts
changed within a function.
start_context (LoggingContext): The initial context we're expecting
Returns:
function
@ -126,13 +133,13 @@ def _check_yield_points(f, changes, start_context):
@functools.wraps(f)
def check_yield_points_inner(*args, **kwargs):
expected_context = start_context
gen = f(*args, **kwargs)
last_yield_line_no = gen.gi_frame.f_lineno
result = None
result: Any = None
while True:
expected_context = LoggingContext.current_context()
try:
isFailure = isinstance(result, Failure)
if isFailure:
@ -200,7 +207,7 @@ def _check_yield_points(f, changes, start_context):
"%s changed context from %s to %s, happened between lines %d and %d in %s"
% (
frame.f_code.co_name,
start_context,
expected_context,
LoggingContext.current_context(),
last_yield_line_no,
frame.f_lineno,
@ -209,8 +216,6 @@ def _check_yield_points(f, changes, start_context):
)
changes.append(err)
expected_context = LoggingContext.current_context()
last_yield_line_no = frame.f_lineno
return check_yield_points_inner