diff --git a/demo/start.sh b/demo/start.sh index 5b3daef57f..b9cc14b9d2 100755 --- a/demo/start.sh +++ b/demo/start.sh @@ -31,6 +31,7 @@ for port in 8080 8081 8082; do #rm $DIR/etc/$port.config python -m synapse.app.homeserver \ --generate-config \ + --enable_registration \ -H "localhost:$https_port" \ --config-path "$DIR/etc/$port.config" \ diff --git a/synapse/crypto/keyclient.py b/synapse/crypto/keyclient.py index 4911f0896b..24f15f3154 100644 --- a/synapse/crypto/keyclient.py +++ b/synapse/crypto/keyclient.py @@ -18,7 +18,9 @@ from twisted.web.http import HTTPClient from twisted.internet.protocol import Factory from twisted.internet import defer, reactor from synapse.http.endpoint import matrix_federation_endpoint -from synapse.util.logcontext import PreserveLoggingContext +from synapse.util.logcontext import ( + preserve_context_over_fn, preserve_context_over_deferred +) import simplejson as json import logging @@ -40,11 +42,14 @@ def fetch_server_key(server_name, ssl_context_factory, path=KEY_API_V1): for i in range(5): try: - with PreserveLoggingContext(): - protocol = yield endpoint.connect(factory) - server_response, server_certificate = yield protocol.remote_key - defer.returnValue((server_response, server_certificate)) - return + protocol = yield preserve_context_over_fn( + endpoint.connect, factory + ) + server_response, server_certificate = yield preserve_context_over_deferred( + protocol.remote_key + ) + defer.returnValue((server_response, server_certificate)) + return except SynapseKeyClientError as e: logger.exception("Error getting key for %r" % (server_name,)) if e.status.startswith("4"): diff --git a/synapse/federation/federation_base.py b/synapse/federation/federation_base.py index 21a763214b..5217d91aab 100644 --- a/synapse/federation/federation_base.py +++ b/synapse/federation/federation_base.py @@ -24,6 +24,8 @@ from synapse.crypto.event_signing import check_event_content_hash from synapse.api.errors import SynapseError +from synapse.util import unwrapFirstError + import logging @@ -94,7 +96,7 @@ class FederationBase(object): yield defer.gatherResults( [do(pdu) for pdu in pdus], consumeErrors=True - ) + ).addErrback(unwrapFirstError) defer.returnValue(signed_pdus) diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index 2b46188c91..cd79e23f4b 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -20,7 +20,6 @@ from .federation_base import FederationBase from .units import Transaction, Edu from synapse.util.logutils import log_function -from synapse.util.logcontext import PreserveLoggingContext from synapse.events import FrozenEvent import synapse.metrics @@ -123,29 +122,28 @@ class FederationServer(FederationBase): logger.debug("[%s] Transaction is new", transaction.transaction_id) - with PreserveLoggingContext(): - results = [] + results = [] - for pdu in pdu_list: - d = self._handle_new_pdu(transaction.origin, pdu) + for pdu in pdu_list: + d = self._handle_new_pdu(transaction.origin, pdu) - try: - yield d - results.append({}) - except FederationError as e: - self.send_failure(e, transaction.origin) - results.append({"error": str(e)}) - except Exception as e: - results.append({"error": str(e)}) - logger.exception("Failed to handle PDU") + try: + yield d + results.append({}) + except FederationError as e: + self.send_failure(e, transaction.origin) + results.append({"error": str(e)}) + except Exception as e: + results.append({"error": str(e)}) + logger.exception("Failed to handle PDU") - if hasattr(transaction, "edus"): - for edu in [Edu(**x) for x in transaction.edus]: - self.received_edu( - transaction.origin, - edu.edu_type, - edu.content - ) + if hasattr(transaction, "edus"): + for edu in [Edu(**x) for x in transaction.edus]: + self.received_edu( + transaction.origin, + edu.edu_type, + edu.content + ) for failure in getattr(transaction, "pdu_failures", []): logger.info("Got failure %r", failure) diff --git a/synapse/handlers/_base.py b/synapse/handlers/_base.py index 4b3f4eadab..ddc5c21e7d 100644 --- a/synapse/handlers/_base.py +++ b/synapse/handlers/_base.py @@ -20,6 +20,8 @@ from synapse.crypto.event_signing import add_hashes_and_signatures from synapse.api.constants import Membership, EventTypes from synapse.types import UserID +from synapse.util.logcontext import PreserveLoggingContext + import logging @@ -137,10 +139,11 @@ class BaseHandler(object): "Failed to get destination from event %s", s.event_id ) - # Don't block waiting on waking up all the listeners. - notify_d = self.notifier.on_new_room_event( - event, extra_users=extra_users - ) + with PreserveLoggingContext(): + # Don't block waiting on waking up all the listeners. + notify_d = self.notifier.on_new_room_event( + event, extra_users=extra_users + ) def log_failure(f): logger.warn( diff --git a/synapse/handlers/events.py b/synapse/handlers/events.py index f9f855213b..993d33ba47 100644 --- a/synapse/handlers/events.py +++ b/synapse/handlers/events.py @@ -15,7 +15,6 @@ from twisted.internet import defer -from synapse.util.logcontext import PreserveLoggingContext from synapse.util.logutils import log_function from synapse.types import UserID from synapse.events.utils import serialize_event @@ -81,10 +80,9 @@ class EventStreamHandler(BaseHandler): # thundering herds on restart. timeout = random.randint(int(timeout*0.9), int(timeout*1.1)) - with PreserveLoggingContext(): - events, tokens = yield self.notifier.get_events_for( - auth_user, room_ids, pagin_config, timeout - ) + events, tokens = yield self.notifier.get_events_for( + auth_user, room_ids, pagin_config, timeout + ) time_now = self.clock.time_msec() diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 1093112587..7d9906039e 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -21,6 +21,8 @@ from synapse.api.errors import ( AuthError, FederationError, StoreError, CodeMessageException, SynapseError, ) from synapse.api.constants import EventTypes, Membership, RejectedReason +from synapse.util import unwrapFirstError +from synapse.util.logcontext import PreserveLoggingContext from synapse.util.logutils import log_function from synapse.util.async import run_on_reactor from synapse.util.frozenutils import unfreeze @@ -199,9 +201,10 @@ class FederationHandler(BaseHandler): target_user = UserID.from_string(target_user_id) extra_users.append(target_user) - d = self.notifier.on_new_room_event( - event, extra_users=extra_users - ) + with PreserveLoggingContext(): + d = self.notifier.on_new_room_event( + event, extra_users=extra_users + ) def log_failure(f): logger.warn( @@ -566,9 +569,10 @@ class FederationHandler(BaseHandler): auth_events=auth_events, ) - d = self.notifier.on_new_room_event( - new_event, extra_users=[joinee] - ) + with PreserveLoggingContext(): + d = self.notifier.on_new_room_event( + new_event, extra_users=[joinee] + ) def log_failure(f): logger.warn( @@ -647,9 +651,10 @@ class FederationHandler(BaseHandler): target_user = UserID.from_string(target_user_id) extra_users.append(target_user) - d = self.notifier.on_new_room_event( - event, extra_users=extra_users - ) + with PreserveLoggingContext(): + d = self.notifier.on_new_room_event( + event, extra_users=extra_users + ) def log_failure(f): logger.warn( @@ -729,9 +734,10 @@ class FederationHandler(BaseHandler): ) target_user = UserID.from_string(event.state_key) - d = self.notifier.on_new_room_event( - event, extra_users=[target_user], - ) + with PreserveLoggingContext(): + d = self.notifier.on_new_room_event( + event, extra_users=[target_user], + ) def log_failure(f): logger.warn( @@ -1056,7 +1062,7 @@ class FederationHandler(BaseHandler): if d in have_events and not have_events[d] ], consumeErrors=True - ) + ).addErrback(unwrapFirstError) if different_events: local_view = dict(auth_events) diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 1809a44a99..867fdbefb0 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -20,6 +20,7 @@ from synapse.api.errors import RoomError, SynapseError from synapse.streams.config import PaginationConfig from synapse.events.utils import serialize_event from synapse.events.validator import EventValidator +from synapse.util import unwrapFirstError from synapse.util.logcontext import PreserveLoggingContext from synapse.types import UserID, RoomStreamToken @@ -313,7 +314,7 @@ class MessageHandler(BaseHandler): event.room_id ), ] - ) + ).addErrback(unwrapFirstError) start_token = now_token.copy_and_replace("room_key", token[0]) end_token = now_token.copy_and_replace("room_key", token[1]) @@ -338,7 +339,7 @@ class MessageHandler(BaseHandler): yield defer.gatherResults( [handle_room(e) for e in room_list], consumeErrors=True - ) + ).addErrback(unwrapFirstError) ret = { "rooms": rooms_ret, diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py index 9e15610401..1edab05492 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py @@ -18,8 +18,8 @@ from twisted.internet import defer from synapse.api.errors import SynapseError, AuthError from synapse.api.constants import PresenceState -from synapse.util.logutils import log_function from synapse.util.logcontext import PreserveLoggingContext +from synapse.util.logutils import log_function from synapse.types import UserID import synapse.metrics @@ -278,15 +278,14 @@ class PresenceHandler(BaseHandler): now_online = state["presence"] != PresenceState.OFFLINE was_polling = target_user in self._user_cachemap - with PreserveLoggingContext(): - if now_online and not was_polling: - self.start_polling_presence(target_user, state=state) - elif not now_online and was_polling: - self.stop_polling_presence(target_user) + if now_online and not was_polling: + self.start_polling_presence(target_user, state=state) + elif not now_online and was_polling: + self.stop_polling_presence(target_user) - # TODO(paul): perform a presence push as part of start/stop poll so - # we don't have to do this all the time - self.changed_presencelike_data(target_user, state) + # TODO(paul): perform a presence push as part of start/stop poll so + # we don't have to do this all the time + self.changed_presencelike_data(target_user, state) def bump_presence_active_time(self, user, now=None): if now is None: @@ -408,10 +407,10 @@ class PresenceHandler(BaseHandler): yield self.store.set_presence_list_accepted( observer_user.localpart, observed_user.to_string() ) - with PreserveLoggingContext(): - self.start_polling_presence( - observer_user, target_user=observed_user - ) + + self.start_polling_presence( + observer_user, target_user=observed_user + ) @defer.inlineCallbacks def deny_presence(self, observed_user, observer_user): @@ -430,10 +429,9 @@ class PresenceHandler(BaseHandler): observer_user.localpart, observed_user.to_string() ) - with PreserveLoggingContext(): - self.stop_polling_presence( - observer_user, target_user=observed_user - ) + self.stop_polling_presence( + observer_user, target_user=observed_user + ) @defer.inlineCallbacks def get_presence_list(self, observer_user, accepted=None): @@ -766,8 +764,7 @@ class PresenceHandler(BaseHandler): if not self._remote_sendmap[user]: del self._remote_sendmap[user] - with PreserveLoggingContext(): - yield defer.DeferredList(deferreds, consumeErrors=True) + yield defer.DeferredList(deferreds, consumeErrors=True) @defer.inlineCallbacks def push_update_to_local_and_remote(self, observed_user, statuscache, @@ -812,10 +809,11 @@ class PresenceHandler(BaseHandler): def push_update_to_clients(self, observed_user, users_to_push=[], room_ids=[], statuscache=None): - self.notifier.on_new_user_event( - users_to_push, - room_ids, - ) + with PreserveLoggingContext(): + self.notifier.on_new_user_event( + users_to_push, + room_ids, + ) class PresenceEventSource(object): diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py index ee2732b848..71ff78ab23 100644 --- a/synapse/handlers/profile.py +++ b/synapse/handlers/profile.py @@ -17,8 +17,8 @@ from twisted.internet import defer from synapse.api.errors import SynapseError, AuthError, CodeMessageException from synapse.api.constants import EventTypes, Membership -from synapse.util.logcontext import PreserveLoggingContext from synapse.types import UserID +from synapse.util import unwrapFirstError from ._base import BaseHandler @@ -154,14 +154,13 @@ class ProfileHandler(BaseHandler): if not self.hs.is_mine(user): defer.returnValue(None) - with PreserveLoggingContext(): - (displayname, avatar_url) = yield defer.gatherResults( - [ - self.store.get_profile_displayname(user.localpart), - self.store.get_profile_avatar_url(user.localpart), - ], - consumeErrors=True - ) + (displayname, avatar_url) = yield defer.gatherResults( + [ + self.store.get_profile_displayname(user.localpart), + self.store.get_profile_avatar_url(user.localpart), + ], + consumeErrors=True + ).addErrback(unwrapFirstError) state["displayname"] = displayname state["avatar_url"] = avatar_url diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 29b6d52757..dac683616a 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -21,7 +21,7 @@ from ._base import BaseHandler from synapse.types import UserID, RoomAlias, RoomID from synapse.api.constants import EventTypes, Membership, JoinRules from synapse.api.errors import StoreError, SynapseError -from synapse.util import stringutils +from synapse.util import stringutils, unwrapFirstError from synapse.util.async import run_on_reactor from synapse.events.utils import serialize_event @@ -537,7 +537,7 @@ class RoomListHandler(BaseHandler): for room in chunk ], consumeErrors=True, - ) + ).addErrback(unwrapFirstError) for i, room in enumerate(chunk): room["num_joined_members"] = len(results[i]) diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py index c0b2bd7db0..64fe51aa3e 100644 --- a/synapse/handlers/typing.py +++ b/synapse/handlers/typing.py @@ -18,6 +18,7 @@ from twisted.internet import defer from ._base import BaseHandler from synapse.api.errors import SynapseError, AuthError +from synapse.util.logcontext import PreserveLoggingContext from synapse.types import UserID import logging @@ -216,7 +217,8 @@ class TypingNotificationHandler(BaseHandler): self._latest_room_serial += 1 self._room_serials[room_id] = self._latest_room_serial - self.notifier.on_new_user_event(rooms=[room_id]) + with PreserveLoggingContext(): + self.notifier.on_new_user_event(rooms=[room_id]) class TypingNotificationEventSource(object): diff --git a/synapse/http/client.py b/synapse/http/client.py index e8a5dedab4..5b3cefb2dc 100644 --- a/synapse/http/client.py +++ b/synapse/http/client.py @@ -14,6 +14,7 @@ # limitations under the License. from synapse.api.errors import CodeMessageException +from synapse.util.logcontext import preserve_context_over_fn from syutil.jsonutil import encode_canonical_json import synapse.metrics @@ -61,7 +62,10 @@ class SimpleHttpClient(object): # A small wrapper around self.agent.request() so we can easily attach # counters to it outgoing_requests_counter.inc(method) - d = self.agent.request(method, *args, **kwargs) + d = preserve_context_over_fn( + self.agent.request, + method, *args, **kwargs + ) def _cb(response): incoming_responses_counter.inc(method, response.code) diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py index 7fa295cad5..c99d237c73 100644 --- a/synapse/http/matrixfederationclient.py +++ b/synapse/http/matrixfederationclient.py @@ -22,7 +22,7 @@ from twisted.web._newclient import ResponseDone from synapse.http.endpoint import matrix_federation_endpoint from synapse.util.async import sleep -from synapse.util.logcontext import PreserveLoggingContext +from synapse.util.logcontext import preserve_context_over_fn import synapse.metrics from syutil.jsonutil import encode_canonical_json @@ -144,22 +144,22 @@ class MatrixFederationHttpClient(object): producer = body_callback(method, url_bytes, headers_dict) try: - with PreserveLoggingContext(): - request_deferred = self.agent.request( - destination, - endpoint, - method, - path_bytes, - param_bytes, - query_bytes, - Headers(headers_dict), - producer - ) + request_deferred = preserve_context_over_fn( + self.agent.request, + destination, + endpoint, + method, + path_bytes, + param_bytes, + query_bytes, + Headers(headers_dict), + producer + ) - response = yield self.clock.time_bound_deferred( - request_deferred, - time_out=60, - ) + response = yield self.clock.time_bound_deferred( + request_deferred, + time_out=60, + ) logger.debug("Got response to %s", method) break diff --git a/synapse/http/server.py b/synapse/http/server.py index 93ecbd7589..73efbff4f2 100644 --- a/synapse/http/server.py +++ b/synapse/http/server.py @@ -17,7 +17,7 @@ from synapse.api.errors import ( cs_exception, SynapseError, CodeMessageException, UnrecognizedRequestError ) -from synapse.util.logcontext import LoggingContext +from synapse.util.logcontext import LoggingContext, PreserveLoggingContext import synapse.metrics from syutil.jsonutil import ( @@ -85,7 +85,9 @@ def request_handler(request_handler): "Received request: %s %s", request.method, request.path ) - yield request_handler(self, request) + d = request_handler(self, request) + with PreserveLoggingContext(): + yield d code = request.code except CodeMessageException as e: code = e.code diff --git a/synapse/notifier.py b/synapse/notifier.py index 78eb28e4b2..7282dfd7f3 100644 --- a/synapse/notifier.py +++ b/synapse/notifier.py @@ -16,7 +16,6 @@ from twisted.internet import defer from synapse.util.logutils import log_function -from synapse.util.logcontext import PreserveLoggingContext from synapse.types import StreamToken import synapse.metrics @@ -223,11 +222,10 @@ class Notifier(object): def eb(failure): logger.exception("Failed to notify listener", failure) - with PreserveLoggingContext(): - yield defer.DeferredList( - [notify(l).addErrback(eb) for l in listeners], - consumeErrors=True, - ) + yield defer.DeferredList( + [notify(l).addErrback(eb) for l in listeners], + consumeErrors=True, + ) @defer.inlineCallbacks @log_function @@ -298,11 +296,10 @@ class Notifier(object): failure.getTracebackObject()) ) - with PreserveLoggingContext(): - yield defer.DeferredList( - [notify(l).addErrback(eb) for l in listeners], - consumeErrors=True, - ) + yield defer.DeferredList( + [notify(l).addErrback(eb) for l in listeners], + consumeErrors=True, + ) @defer.inlineCallbacks def wait_for_events(self, user, rooms, filter, timeout, callback): diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index 9e348590ba..c9fe5a3555 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -18,7 +18,7 @@ from synapse.api.errors import StoreError from synapse.events import FrozenEvent from synapse.events.utils import prune_event from synapse.util.logutils import log_function -from synapse.util.logcontext import PreserveLoggingContext, LoggingContext +from synapse.util.logcontext import preserve_context_over_fn, LoggingContext from synapse.util.lrucache import LruCache import synapse.metrics @@ -420,10 +420,11 @@ class SQLBaseStore(object): self._txn_perf_counters.update(desc, start, end) sql_txn_timer.inc_by(duration, desc) - with PreserveLoggingContext(): - result = yield self._db_pool.runWithConnection( - inner_func, *args, **kwargs - ) + result = yield preserve_context_over_fn( + self._db_pool.runWithConnection, + inner_func, *args, **kwargs + ) + for after_callback, after_args in after_callbacks: after_callback(*after_args) defer.returnValue(result) diff --git a/synapse/util/__init__.py b/synapse/util/__init__.py index 79109d0b19..c1a16b639a 100644 --- a/synapse/util/__init__.py +++ b/synapse/util/__init__.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from synapse.util.logcontext import LoggingContext +from synapse.util.logcontext import LoggingContext, PreserveLoggingContext from twisted.internet import defer, reactor, task @@ -23,6 +23,12 @@ import logging logger = logging.getLogger(__name__) +def unwrapFirstError(failure): + # defer.gatherResults and DeferredLists wrap failures. + failure.trap(defer.FirstError) + return failure.value.subFailure + + class Clock(object): """A small utility that obtains current time-of-day so that time may be mocked during unit-tests. @@ -50,9 +56,12 @@ class Clock(object): current_context = LoggingContext.current_context() def wrapped_callback(): - LoggingContext.thread_local.current_context = current_context - callback() - return reactor.callLater(delay, wrapped_callback) + with PreserveLoggingContext(): + LoggingContext.thread_local.current_context = current_context + callback() + + with PreserveLoggingContext(): + return reactor.callLater(delay, wrapped_callback) def cancel_call_later(self, timer): timer.cancel() diff --git a/synapse/util/async.py b/synapse/util/async.py index d8febdb90c..f78395a431 100644 --- a/synapse/util/async.py +++ b/synapse/util/async.py @@ -16,15 +16,13 @@ from twisted.internet import defer, reactor -from .logcontext import PreserveLoggingContext +from .logcontext import preserve_context_over_deferred -@defer.inlineCallbacks def sleep(seconds): d = defer.Deferred() reactor.callLater(seconds, d.callback, seconds) - with PreserveLoggingContext(): - yield d + return preserve_context_over_deferred(d) def run_on_reactor(): diff --git a/synapse/util/distributor.py b/synapse/util/distributor.py index 9d9c350397..064c4a7a1e 100644 --- a/synapse/util/distributor.py +++ b/synapse/util/distributor.py @@ -13,10 +13,14 @@ # See the License for the specific language governing permissions and # limitations under the License. -from synapse.util.logcontext import PreserveLoggingContext - from twisted.internet import defer +from synapse.util.logcontext import ( + PreserveLoggingContext, preserve_context_over_deferred, +) + +from synapse.util import unwrapFirstError + import logging @@ -93,7 +97,6 @@ class Signal(object): Each observer callable may return a Deferred.""" self.observers.append(observer) - @defer.inlineCallbacks def fire(self, *args, **kwargs): """Invokes every callable in the observer list, passing in the args and kwargs. Exceptions thrown by observers are logged but ignored. It is @@ -101,24 +104,28 @@ class Signal(object): Returns a Deferred that will complete when all the observers have completed.""" - with PreserveLoggingContext(): - deferreds = [] - for observer in self.observers: - d = defer.maybeDeferred(observer, *args, **kwargs) - def eb(failure): - logger.warning( - "%s signal observer %s failed: %r", - self.name, observer, failure, - exc_info=( - failure.type, - failure.value, - failure.getTracebackObject())) - if not self.suppress_failures: - failure.raiseException() - deferreds.append(d.addErrback(eb)) - results = [] - for deferred in deferreds: - result = yield deferred - results.append(result) - defer.returnValue(results) + def do(observer): + def eb(failure): + logger.warning( + "%s signal observer %s failed: %r", + self.name, observer, failure, + exc_info=( + failure.type, + failure.value, + failure.getTracebackObject())) + if not self.suppress_failures: + return failure + return defer.maybeDeferred(observer, *args, **kwargs).addErrback(eb) + + with PreserveLoggingContext(): + deferreds = [ + do(observer) + for observer in self.observers + ] + + d = defer.gatherResults(deferreds, consumeErrors=True) + + d.addErrback(unwrapFirstError) + + return preserve_context_over_deferred(d) diff --git a/synapse/util/logcontext.py b/synapse/util/logcontext.py index da7872e95d..a92d518b43 100644 --- a/synapse/util/logcontext.py +++ b/synapse/util/logcontext.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from twisted.internet import defer + import threading import logging @@ -129,3 +131,53 @@ class PreserveLoggingContext(object): def __exit__(self, type, value, traceback): """Restores the current logging context""" LoggingContext.thread_local.current_context = self.current_context + + if self.current_context is not LoggingContext.sentinel: + if self.current_context.parent_context is None: + logger.warn( + "Restoring dead context: %s", + self.current_context, + ) + + +def preserve_context_over_fn(fn, *args, **kwargs): + """Takes a function and invokes it with the given arguments, but removes + and restores the current logging context while doing so. + + If the result is a deferred, call preserve_context_over_deferred before + returning it. + """ + with PreserveLoggingContext(): + res = fn(*args, **kwargs) + + if isinstance(res, defer.Deferred): + return preserve_context_over_deferred(res) + else: + return res + + +def preserve_context_over_deferred(deferred): + """Given a deferred wrap it such that any callbacks added later to it will + be invoked with the current context. + """ + d = defer.Deferred() + + current_context = LoggingContext.current_context() + + def cb(res): + with PreserveLoggingContext(): + LoggingContext.thread_local.current_context = current_context + res = d.callback(res) + return res + + def eb(failure): + with PreserveLoggingContext(): + LoggingContext.thread_local.current_context = current_context + res = d.errback(failure) + return res + + if deferred.called: + return deferred + + deferred.addCallbacks(cb, eb) + return d