136 lines
		
	
	
		
			5.2 KiB
		
	
	
	
		
			Python
		
	
	
			
		
		
	
	
			136 lines
		
	
	
		
			5.2 KiB
		
	
	
	
		
			Python
		
	
	
# Copyright 2021 The Matrix.org Foundation C.I.C.
 | 
						|
#
 | 
						|
# Licensed under the Apache License, Version 2.0 (the "License");
 | 
						|
# you may not use this file except in compliance with the License.
 | 
						|
# You may obtain a copy of the License at
 | 
						|
#
 | 
						|
#     http://www.apache.org/licenses/LICENSE-2.0
 | 
						|
#
 | 
						|
# Unless required by applicable law or agreed to in writing, software
 | 
						|
# distributed under the License is distributed on an "AS IS" BASIS,
 | 
						|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
						|
# See the License for the specific language governing permissions and
 | 
						|
# limitations under the License.
 | 
						|
import enum
 | 
						|
from typing import Awaitable, Callable, Generic, Optional, TypeVar, Union
 | 
						|
 | 
						|
from twisted.internet.defer import Deferred
 | 
						|
from twisted.python.failure import Failure
 | 
						|
 | 
						|
from synapse.logging.context import make_deferred_yieldable, run_in_background
 | 
						|
 | 
						|
TV = TypeVar("TV")
 | 
						|
 | 
						|
 | 
						|
class _Sentinel(enum.Enum):
    """Marker type for "no result has been stored yet".

    CachedCall initialises its `_result` slot to `_Sentinel.sentinel` and uses
    `isinstance(..., _Sentinel)` to tell whether the underlying call has
    completed. Being a distinct type (rather than, say, None) lets it appear in
    the Union annotation of `_result` alongside the real result type.
    """

    # the one and only member; its value is an opaque, unique object
    sentinel = object()
 | 
						|
 | 
						|
 | 
						|
class CachedCall(Generic[TV]):
    """Share the outcome of one asynchronous call between many callers

    Wraps a zero-argument asynchronous function so that, however many times get()
    is awaited, the function is invoked only once and every caller receives the
    same result (or the same exception).

    A lock could give a similar effect, but usually needs more boilerplate and is
    less efficient.

    Synapse logcontexts are handled correctly: logs and resource usage for the
    underlying function are attributed to whichever logcontext is active when
    get() is first called.

    Example usage:

        _cached_val = CachedCall(_load_prop)

        async def handle_request() -> X:
            # However many requests we handle, _load_prop() is only called once.
            return await _cached_val.get()

        async def _load_prop() -> X:
            return await difficult_operation()


    The implementation is deliberately single-shot: once the call has been
    initiated there is no way to ask for it to be run again. That keeps both the
    implementation and the semantics simple. To make a fresh call, replace the
    whole CachedCall object.
    """

    __slots__ = ["_callable", "_deferred", "_result"]

    def __init__(self, f: Callable[[], Awaitable[TV]]):
        """
        Args:
            f: The underlying function. It is started by the first call to get(),
                so only one call to it will be alive at once (per instance of
                CachedCall)
        """
        # stays as the sentinel until the underlying call has completed
        self._result: Union[_Sentinel, TV, Failure] = _Sentinel.sentinel
        # set once the underlying call has been started
        self._deferred: Optional[Deferred] = None
        # dropped (set to None) as soon as the call has been started
        self._callable: Optional[Callable[[], Awaitable[TV]]] = f

    async def get(self) -> TV:
        """Start the underlying call if nobody has yet, then return its result

        Returns:
            The value produced by the underlying function.

        Raises:
            Whatever exception the underlying function raised, re-raised for each
            caller.
        """

        if self._deferred is None:
            # We are the first caller: start the underlying function.
            pending = run_in_background(self._callable)
            self._deferred = pending

            # The callable is never needed again; drop it so it can be GCed.
            self._callable = None

            # Copy the outcome out of the deferred as soon as it is available.
            # It cannot stay in the deferred, because awaiting a deferred
            # destroys its result. (Additionally, a Failure left inside a
            # deferred which then gets GCed is logged as a critical "unhandled
            # Failure" error.)
            def _stash_outcome(outcome):
                self._result = outcome

            pending.addBoth(_stash_outcome)

        # TODO: consider cancellation semantics. Currently, if the call to get()
        #    is cancelled, the underlying call will continue (and any future calls
        #    will get the result/exception), which I think is *probably* ok, modulo
        #    the fact the underlying call may be logged to a cancelled logcontext,
        #    and any eventual exception may not be reported.

        # If there is no stored outcome yet, wait for the deferred to fire; the
        # callback above will have populated self._result by the time we resume.
        if isinstance(self._result, _Sentinel):
            await make_deferred_yieldable(self._deferred)
            assert not isinstance(self._result, _Sentinel)

        outcome = self._result
        if not isinstance(outcome, Failure):
            return outcome

        # The underlying call failed: re-raise its exception for this caller.
        outcome.raiseException()
        raise AssertionError("unexpected return from Failure.raiseException")
 | 
						|
 | 
						|
 | 
						|
class RetryOnExceptionCachedCall(Generic[TV]):
    """A wrapper around CachedCall which will retry the call if an exception is thrown

    This is used in much the same way as CachedCall, but adds some extra functionality
    so that if the underlying function throws an exception, then the next call to get()
    will initiate another call to the underlying function. (Any calls to get() which
    are already pending will raise the exception.)
    """

    # NB: this must be spelled `__slots__`; a plain `slots` attribute is just an
    # ordinary class attribute and leaves every instance with a `__dict__`.
    __slots__ = ["_cachedcall"]

    def __init__(self, f: Callable[[], Awaitable[TV]]):
        """
        Args:
            f: The underlying function. It is called again on the next get() if
                the previous call raised an Exception.
        """

        async def _wrapper() -> TV:
            try:
                return await f()
            except Exception:
                # the call raised an exception: replace the underlying CachedCall to
                # trigger another call next time get() is called. Callers already
                # waiting on the old CachedCall still see the exception, since we
                # re-raise it below.
                self._cachedcall = CachedCall(_wrapper)
                raise

        self._cachedcall = CachedCall(_wrapper)

    async def get(self) -> TV:
        """Return the cached result, starting (or restarting) the call if needed

        Raises:
            Whatever exception the underlying function raised on this attempt.
        """
        return await self._cachedcall.get()
 |