Make Distributor run its processes as a background process

This is more involved than it might otherwise be, because the current
implementation just drops its logcontexts and runs everything in the sentinel
context.

It turns out that we aren't actually using a bunch of the functionality here
(notably suppress_failures and the fact that Distributor.fire returns a
deferred), so the easiest way to fix this is actually by simplifying a bunch of
code.
pull/3556/head
Richard van der Hoff 2018-07-18 15:33:13 +01:00
parent 08436c556a
commit 8c69b735e3
2 changed files with 22 additions and 78 deletions

View File

@ -17,20 +17,18 @@ import logging
from twisted.internet import defer from twisted.internet import defer
from synapse.util import unwrapFirstError from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.util.logcontext import PreserveLoggingContext from synapse.util.logcontext import make_deferred_yieldable, run_in_background
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def user_left_room(distributor, user, room_id): def user_left_room(distributor, user, room_id):
with PreserveLoggingContext(): distributor.fire("user_left_room", user=user, room_id=room_id)
distributor.fire("user_left_room", user=user, room_id=room_id)
def user_joined_room(distributor, user, room_id): def user_joined_room(distributor, user, room_id):
with PreserveLoggingContext(): distributor.fire("user_joined_room", user=user, room_id=room_id)
distributor.fire("user_joined_room", user=user, room_id=room_id)
class Distributor(object): class Distributor(object):
@ -44,9 +42,7 @@ class Distributor(object):
model will do for today. model will do for today.
""" """
def __init__(self, suppress_failures=True): def __init__(self):
self.suppress_failures = suppress_failures
self.signals = {} self.signals = {}
self.pre_registration = {} self.pre_registration = {}
@ -56,7 +52,6 @@ class Distributor(object):
self.signals[name] = Signal( self.signals[name] = Signal(
name, name,
suppress_failures=self.suppress_failures,
) )
if name in self.pre_registration: if name in self.pre_registration:
@ -82,7 +77,11 @@ class Distributor(object):
if name not in self.signals: if name not in self.signals:
raise KeyError("%r does not have a signal named %s" % (self, name)) raise KeyError("%r does not have a signal named %s" % (self, name))
return self.signals[name].fire(*args, **kwargs) run_as_background_process(
name,
self.signals[name].fire,
*args, **kwargs
)
class Signal(object): class Signal(object):
@ -95,9 +94,8 @@ class Signal(object):
method into all of the observers. method into all of the observers.
""" """
def __init__(self, name, suppress_failures): def __init__(self, name):
self.name = name self.name = name
self.suppress_failures = suppress_failures
self.observers = [] self.observers = []
def observe(self, observer): def observe(self, observer):
@ -107,7 +105,6 @@ class Signal(object):
Each observer callable may return a Deferred.""" Each observer callable may return a Deferred."""
self.observers.append(observer) self.observers.append(observer)
@defer.inlineCallbacks
def fire(self, *args, **kwargs): def fire(self, *args, **kwargs):
"""Invokes every callable in the observer list, passing in the args and """Invokes every callable in the observer list, passing in the args and
kwargs. Exceptions thrown by observers are logged but ignored. It is kwargs. Exceptions thrown by observers are logged but ignored. It is
@ -125,22 +122,17 @@ class Signal(object):
failure.type, failure.type,
failure.value, failure.value,
failure.getTracebackObject())) failure.getTracebackObject()))
if not self.suppress_failures:
return failure
return defer.maybeDeferred(observer, *args, **kwargs).addErrback(eb) return defer.maybeDeferred(observer, *args, **kwargs).addErrback(eb)
with PreserveLoggingContext(): deferreds = [
deferreds = [ run_in_background(do, o)
do(observer) for o in self.observers
for observer in self.observers ]
]
res = yield defer.gatherResults( return make_deferred_yieldable(defer.gatherResults(
deferreds, consumeErrors=True deferreds, consumeErrors=True,
).addErrback(unwrapFirstError) ))
defer.returnValue(res)
def __repr__(self): def __repr__(self):
return "<Signal name=%r>" % (self.name,) return "<Signal name=%r>" % (self.name,)

View File

@ -1,5 +1,6 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd # Copyright 2014-2016 OpenMarket Ltd
# Copyright 2018 New Vector Ltd
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -15,8 +16,6 @@
from mock import Mock, patch from mock import Mock, patch
from twisted.internet import defer
from synapse.util.distributor import Distributor from synapse.util.distributor import Distributor
from . import unittest from . import unittest
@ -27,38 +26,15 @@ class DistributorTestCase(unittest.TestCase):
def setUp(self): def setUp(self):
self.dist = Distributor() self.dist = Distributor()
@defer.inlineCallbacks
def test_signal_dispatch(self): def test_signal_dispatch(self):
self.dist.declare("alert") self.dist.declare("alert")
observer = Mock() observer = Mock()
self.dist.observe("alert", observer) self.dist.observe("alert", observer)
d = self.dist.fire("alert", 1, 2, 3) self.dist.fire("alert", 1, 2, 3)
yield d
self.assertTrue(d.called)
observer.assert_called_with(1, 2, 3) observer.assert_called_with(1, 2, 3)
@defer.inlineCallbacks
def test_signal_dispatch_deferred(self):
self.dist.declare("whine")
d_inner = defer.Deferred()
def observer():
return d_inner
self.dist.observe("whine", observer)
d_outer = self.dist.fire("whine")
self.assertFalse(d_outer.called)
d_inner.callback(None)
yield d_outer
self.assertTrue(d_outer.called)
@defer.inlineCallbacks
def test_signal_catch(self): def test_signal_catch(self):
self.dist.declare("alarm") self.dist.declare("alarm")
@ -71,9 +47,7 @@ class DistributorTestCase(unittest.TestCase):
with patch( with patch(
"synapse.util.distributor.logger", spec=["warning"] "synapse.util.distributor.logger", spec=["warning"]
) as mock_logger: ) as mock_logger:
d = self.dist.fire("alarm", "Go") self.dist.fire("alarm", "Go")
yield d
self.assertTrue(d.called)
observers[0].assert_called_once_with("Go") observers[0].assert_called_once_with("Go")
observers[1].assert_called_once_with("Go") observers[1].assert_called_once_with("Go")
@ -83,34 +57,12 @@ class DistributorTestCase(unittest.TestCase):
mock_logger.warning.call_args[0][0], str mock_logger.warning.call_args[0][0], str
) )
@defer.inlineCallbacks
def test_signal_catch_no_suppress(self):
# Gut-wrenching
self.dist.suppress_failures = False
self.dist.declare("whail")
class MyException(Exception):
pass
@defer.inlineCallbacks
def observer():
raise MyException("Oopsie")
self.dist.observe("whail", observer)
d = self.dist.fire("whail")
yield self.assertFailure(d, MyException)
self.dist.suppress_failures = True
@defer.inlineCallbacks
def test_signal_prereg(self): def test_signal_prereg(self):
observer = Mock() observer = Mock()
self.dist.observe("flare", observer) self.dist.observe("flare", observer)
self.dist.declare("flare") self.dist.declare("flare")
yield self.dist.fire("flare", 4, 5) self.dist.fire("flare", 4, 5)
observer.assert_called_with(4, 5) observer.assert_called_with(4, 5)