136 lines
5.2 KiB
Python
136 lines
5.2 KiB
Python
# Copyright 2021 The Matrix.org Foundation C.I.C.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
import enum
|
|
from typing import Awaitable, Callable, Generic, Optional, TypeVar, Union
|
|
|
|
from twisted.internet.defer import Deferred
|
|
from twisted.python.failure import Failure
|
|
|
|
from synapse.logging.context import make_deferred_yieldable, run_in_background
|
|
|
|
TV = TypeVar("TV")
|
|
|
|
|
|
class _Sentinel(enum.Enum):
|
|
sentinel = object()
|
|
|
|
|
|
class CachedCall(Generic[TV]):
|
|
"""A wrapper for asynchronous calls whose results should be shared
|
|
|
|
This is useful for wrapping asynchronous functions, where there might be multiple
|
|
callers, but we only want to call the underlying function once (and have the result
|
|
returned to all callers).
|
|
|
|
Similar results can be achieved via a lock of some form, but that typically requires
|
|
more boilerplate (and ends up being less efficient).
|
|
|
|
Correctly handles Synapse logcontexts (logs and resource usage for the underlying
|
|
function are logged against the logcontext which is active when get() is first
|
|
called).
|
|
|
|
Example usage:
|
|
|
|
_cached_val = CachedCall(_load_prop)
|
|
|
|
async def handle_request() -> X:
|
|
# We can call this multiple times, but it will result in a single call to
|
|
# _load_prop().
|
|
return await _cached_val.get()
|
|
|
|
async def _load_prop() -> X:
|
|
await difficult_operation()
|
|
|
|
|
|
The implementation is deliberately single-shot (ie, once the call is initiated,
|
|
there is no way to ask for it to be run). This keeps the implementation and
|
|
semantics simple. If you want to make a new call, simply replace the whole
|
|
CachedCall object.
|
|
"""
|
|
|
|
__slots__ = ["_callable", "_deferred", "_result"]
|
|
|
|
def __init__(self, f: Callable[[], Awaitable[TV]]):
|
|
"""
|
|
Args:
|
|
f: The underlying function. Only one call to this function will be alive
|
|
at once (per instance of CachedCall)
|
|
"""
|
|
self._callable: Optional[Callable[[], Awaitable[TV]]] = f
|
|
self._deferred: Optional[Deferred] = None
|
|
self._result: Union[_Sentinel, TV, Failure] = _Sentinel.sentinel
|
|
|
|
async def get(self) -> TV:
|
|
"""Kick off the call if necessary, and return the result"""
|
|
|
|
# Fire off the callable now if this is our first time
|
|
if not self._deferred:
|
|
self._deferred = run_in_background(self._callable)
|
|
|
|
# we will never need the callable again, so make sure it can be GCed
|
|
self._callable = None
|
|
|
|
# once the deferred completes, store the result. We cannot simply leave the
|
|
# result in the deferred, since `awaiting` a deferred destroys its result.
|
|
# (Also, if it's a Failure, GCing the deferred would log a critical error
|
|
# about unhandled Failures)
|
|
def got_result(r: Union[TV, Failure]) -> None:
|
|
self._result = r
|
|
|
|
self._deferred.addBoth(got_result)
|
|
|
|
# TODO: consider cancellation semantics. Currently, if the call to get()
|
|
# is cancelled, the underlying call will continue (and any future calls
|
|
# will get the result/exception), which I think is *probably* ok, modulo
|
|
# the fact the underlying call may be logged to a cancelled logcontext,
|
|
# and any eventual exception may not be reported.
|
|
|
|
# we can now await the deferred, and once it completes, return the result.
|
|
if isinstance(self._result, _Sentinel):
|
|
await make_deferred_yieldable(self._deferred)
|
|
assert not isinstance(self._result, _Sentinel)
|
|
|
|
if isinstance(self._result, Failure):
|
|
self._result.raiseException()
|
|
raise AssertionError("unexpected return from Failure.raiseException")
|
|
|
|
return self._result
|
|
|
|
|
|
class RetryOnExceptionCachedCall(Generic[TV]):
|
|
"""A wrapper around CachedCall which will retry the call if an exception is thrown
|
|
|
|
This is used in much the same way as CachedCall, but adds some extra functionality
|
|
so that if the underlying function throws an exception, then the next call to get()
|
|
will initiate another call to the underlying function. (Any calls to get() which
|
|
are already pending will raise the exception.)
|
|
"""
|
|
|
|
slots = ["_cachedcall"]
|
|
|
|
def __init__(self, f: Callable[[], Awaitable[TV]]):
|
|
async def _wrapper() -> TV:
|
|
try:
|
|
return await f()
|
|
except Exception:
|
|
# the call raised an exception: replace the underlying CachedCall to
|
|
# trigger another call next time get() is called
|
|
self._cachedcall = CachedCall(_wrapper)
|
|
raise
|
|
|
|
self._cachedcall = CachedCall(_wrapper)
|
|
|
|
async def get(self) -> TV:
|
|
return await self._cachedcall.get()
|