Correctly await on_logged_out callbacks (#11786)

pull/11793/head
Brendan Abolivier 2022-01-20 19:19:40 +01:00 committed by GitHub
parent d09099642e
commit bfe6d5553a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 29 additions and 2 deletions

1
changelog.d/11786.bugfix Normal file
View File

@ -0,0 +1 @@
Fix a bug introduced in Synapse 1.46.0 that prevented `on_logged_out` module callbacks from being correctly awaited by Synapse.

View File

@ -2281,7 +2281,7 @@ class PasswordAuthProvider:
# call all of the on_logged_out callbacks
for callback in self.on_logged_out_callbacks:
try:
callback(user_id, device_id, access_token)
await callback(user_id, device_id, access_token)
except Exception as e:
logger.warning("Failed to run module API callback %s: %s", callback, e)
continue

View File

@ -22,7 +22,7 @@ from twisted.internet import defer
import synapse
from synapse.handlers.auth import load_legacy_password_auth_providers
from synapse.module_api import ModuleApi
from synapse.rest.client import devices, login
from synapse.rest.client import devices, login, logout
from synapse.types import JsonDict
from tests import unittest
@ -155,6 +155,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
synapse.rest.admin.register_servlets,
login.register_servlets,
devices.register_servlets,
logout.register_servlets,
]
def setUp(self):
@ -719,6 +720,31 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
channel = self._send_password_login("localuser", "localpass")
self.assertEqual(channel.code, 400, channel.result)
def test_on_logged_out(self):
"""Tests that the on_logged_out callback is called when the user logs out."""
self.register_user("rin", "password")
tok = self.login("rin", "password")
self.called = False
async def on_logged_out(user_id, device_id, access_token):
self.called = True
on_logged_out = Mock(side_effect=on_logged_out)
self.hs.get_password_auth_provider().on_logged_out_callbacks.append(
on_logged_out
)
channel = self.make_request(
"POST",
"/_matrix/client/v3/logout",
{},
access_token=tok,
)
self.assertEqual(channel.code, 200)
on_logged_out.assert_called_once()
self.assertTrue(self.called)
def _get_login_flows(self) -> JsonDict:
channel = self.make_request("GET", "/_matrix/client/r0/login")
self.assertEqual(channel.code, 200, channel.result)