Trace functions which return `Awaitable` (#15650)
parent
4e6390cb10
commit
8bfded81f3
|
@ -0,0 +1 @@
|
||||||
|
Add support for tracing functions which return `Awaitable`s.
|
|
@ -171,6 +171,7 @@ from functools import wraps
|
||||||
from typing import (
|
from typing import (
|
||||||
TYPE_CHECKING,
|
TYPE_CHECKING,
|
||||||
Any,
|
Any,
|
||||||
|
Awaitable,
|
||||||
Callable,
|
Callable,
|
||||||
Collection,
|
Collection,
|
||||||
ContextManager,
|
ContextManager,
|
||||||
|
@ -903,6 +904,7 @@ def _custom_sync_async_decorator(
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if inspect.iscoroutinefunction(func):
|
if inspect.iscoroutinefunction(func):
|
||||||
|
# For this branch, we handle async functions like `async def func() -> RInner`.
|
||||||
# In this branch, R = Awaitable[RInner], for some other type RInner
|
# In this branch, R = Awaitable[RInner], for some other type RInner
|
||||||
@wraps(func)
|
@wraps(func)
|
||||||
async def _wrapper(
|
async def _wrapper(
|
||||||
|
@ -914,15 +916,16 @@ def _custom_sync_async_decorator(
|
||||||
return await func(*args, **kwargs) # type: ignore[misc]
|
return await func(*args, **kwargs) # type: ignore[misc]
|
||||||
|
|
||||||
else:
|
else:
|
||||||
# The other case here handles both sync functions and those
|
# The other case here handles sync functions including those decorated with
|
||||||
# decorated with inlineDeferred.
|
# `@defer.inlineCallbacks` or that return a `Deferred` or other `Awaitable`.
|
||||||
@wraps(func)
|
@wraps(func)
|
||||||
def _wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
|
def _wrapper(*args: P.args, **kwargs: P.kwargs) -> Any:
|
||||||
scope = wrapping_logic(func, *args, **kwargs)
|
scope = wrapping_logic(func, *args, **kwargs)
|
||||||
scope.__enter__()
|
scope.__enter__()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
result = func(*args, **kwargs)
|
result = func(*args, **kwargs)
|
||||||
|
|
||||||
if isinstance(result, defer.Deferred):
|
if isinstance(result, defer.Deferred):
|
||||||
|
|
||||||
def call_back(result: R) -> R:
|
def call_back(result: R) -> R:
|
||||||
|
@ -930,20 +933,32 @@ def _custom_sync_async_decorator(
|
||||||
return result
|
return result
|
||||||
|
|
||||||
def err_back(result: R) -> R:
|
def err_back(result: R) -> R:
|
||||||
|
# TODO: Pass the error details into `scope.__exit__(...)` for
|
||||||
|
# consistency with the other paths.
|
||||||
scope.__exit__(None, None, None)
|
scope.__exit__(None, None, None)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
result.addCallbacks(call_back, err_back)
|
result.addCallbacks(call_back, err_back)
|
||||||
|
|
||||||
else:
|
elif inspect.isawaitable(result):
|
||||||
if inspect.isawaitable(result):
|
|
||||||
logger.error(
|
|
||||||
"@trace may not have wrapped %s correctly! "
|
|
||||||
"The function is not async but returned a %s.",
|
|
||||||
func.__qualname__,
|
|
||||||
type(result).__name__,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
async def wrap_awaitable() -> Any:
|
||||||
|
try:
|
||||||
|
assert isinstance(result, Awaitable)
|
||||||
|
awaited_result = await result
|
||||||
|
scope.__exit__(None, None, None)
|
||||||
|
return awaited_result
|
||||||
|
except Exception as e:
|
||||||
|
scope.__exit__(type(e), None, e.__traceback__)
|
||||||
|
raise
|
||||||
|
|
||||||
|
# The original method returned an awaitable, eg. a coroutine, so we
|
||||||
|
# create another awaitable wrapping it that calls
|
||||||
|
# `scope.__exit__(...)`.
|
||||||
|
return wrap_awaitable()
|
||||||
|
else:
|
||||||
|
# Just a simple sync function so we can just exit the scope and
|
||||||
|
# return the result without any fuss.
|
||||||
scope.__exit__(None, None, None)
|
scope.__exit__(None, None, None)
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
|
@ -12,7 +12,7 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
from typing import cast
|
from typing import Awaitable, cast
|
||||||
|
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
from twisted.test.proto_helpers import MemoryReactorClock
|
from twisted.test.proto_helpers import MemoryReactorClock
|
||||||
|
@ -227,8 +227,6 @@ class LogContextScopeManagerTestCase(TestCase):
|
||||||
Test whether we can use `@trace_with_opname` (`@trace`) and `@tag_args`
|
Test whether we can use `@trace_with_opname` (`@trace`) and `@tag_args`
|
||||||
with functions that return deferreds
|
with functions that return deferreds
|
||||||
"""
|
"""
|
||||||
reactor = MemoryReactorClock()
|
|
||||||
|
|
||||||
with LoggingContext("root context"):
|
with LoggingContext("root context"):
|
||||||
|
|
||||||
@trace_with_opname("fixture_deferred_func", tracer=self._tracer)
|
@trace_with_opname("fixture_deferred_func", tracer=self._tracer)
|
||||||
|
@ -240,9 +238,6 @@ class LogContextScopeManagerTestCase(TestCase):
|
||||||
|
|
||||||
result_d1 = fixture_deferred_func()
|
result_d1 = fixture_deferred_func()
|
||||||
|
|
||||||
# let the tasks complete
|
|
||||||
reactor.pump((2,) * 8)
|
|
||||||
|
|
||||||
self.assertEqual(self.successResultOf(result_d1), "foo")
|
self.assertEqual(self.successResultOf(result_d1), "foo")
|
||||||
|
|
||||||
# the span should have been reported
|
# the span should have been reported
|
||||||
|
@ -256,8 +251,6 @@ class LogContextScopeManagerTestCase(TestCase):
|
||||||
Test whether we can use `@trace_with_opname` (`@trace`) and `@tag_args`
|
Test whether we can use `@trace_with_opname` (`@trace`) and `@tag_args`
|
||||||
with async functions
|
with async functions
|
||||||
"""
|
"""
|
||||||
reactor = MemoryReactorClock()
|
|
||||||
|
|
||||||
with LoggingContext("root context"):
|
with LoggingContext("root context"):
|
||||||
|
|
||||||
@trace_with_opname("fixture_async_func", tracer=self._tracer)
|
@trace_with_opname("fixture_async_func", tracer=self._tracer)
|
||||||
|
@ -267,9 +260,6 @@ class LogContextScopeManagerTestCase(TestCase):
|
||||||
|
|
||||||
d1 = defer.ensureDeferred(fixture_async_func())
|
d1 = defer.ensureDeferred(fixture_async_func())
|
||||||
|
|
||||||
# let the tasks complete
|
|
||||||
reactor.pump((2,) * 8)
|
|
||||||
|
|
||||||
self.assertEqual(self.successResultOf(d1), "foo")
|
self.assertEqual(self.successResultOf(d1), "foo")
|
||||||
|
|
||||||
# the span should have been reported
|
# the span should have been reported
|
||||||
|
@ -277,3 +267,34 @@ class LogContextScopeManagerTestCase(TestCase):
|
||||||
[span.operation_name for span in self._reporter.get_spans()],
|
[span.operation_name for span in self._reporter.get_spans()],
|
||||||
["fixture_async_func"],
|
["fixture_async_func"],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def test_trace_decorator_awaitable_return(self) -> None:
|
||||||
|
"""
|
||||||
|
Test whether we can use `@trace_with_opname` (`@trace`) and `@tag_args`
|
||||||
|
with functions that return an awaitable (e.g. a coroutine)
|
||||||
|
"""
|
||||||
|
with LoggingContext("root context"):
|
||||||
|
# Something we can return without `await` to get a coroutine
|
||||||
|
async def fixture_async_func() -> str:
|
||||||
|
return "foo"
|
||||||
|
|
||||||
|
# The actual kind of function we want to test that returns an awaitable
|
||||||
|
@trace_with_opname("fixture_awaitable_return_func", tracer=self._tracer)
|
||||||
|
@tag_args
|
||||||
|
def fixture_awaitable_return_func() -> Awaitable[str]:
|
||||||
|
return fixture_async_func()
|
||||||
|
|
||||||
|
# Something we can run with `defer.ensureDeferred(runner())` and pump the
|
||||||
|
# whole async tasks through to completion.
|
||||||
|
async def runner() -> str:
|
||||||
|
return await fixture_awaitable_return_func()
|
||||||
|
|
||||||
|
d1 = defer.ensureDeferred(runner())
|
||||||
|
|
||||||
|
self.assertEqual(self.successResultOf(d1), "foo")
|
||||||
|
|
||||||
|
# the span should have been reported
|
||||||
|
self.assertEqual(
|
||||||
|
[span.operation_name for span in self._reporter.get_spans()],
|
||||||
|
["fixture_awaitable_return_func"],
|
||||||
|
)
|
||||||
|
|
Loading…
Reference in New Issue