Allow modules to set a display name on registration (#12009)
Co-authored-by: Patrick Cloke <clokep@users.noreply.github.com>pull/12020/head
parent
da0e9f8efd
commit
707049c6ff
|
@ -0,0 +1 @@
|
||||||
|
Enable modules to set a custom display name when registering a user.
|
|
@ -85,7 +85,7 @@ If the authentication is unsuccessful, the module must return `None`.
|
||||||
If multiple modules implement this callback, they will be considered in order. If a
|
If multiple modules implement this callback, they will be considered in order. If a
|
||||||
callback returns `None`, Synapse falls through to the next one. The value of the first
|
callback returns `None`, Synapse falls through to the next one. The value of the first
|
||||||
callback that does not return `None` will be used. If this happens, Synapse will not call
|
callback that does not return `None` will be used. If this happens, Synapse will not call
|
||||||
any of the subsequent implementations of this callback. If every callback return `None`,
|
any of the subsequent implementations of this callback. If every callback returns `None`,
|
||||||
the authentication is denied.
|
the authentication is denied.
|
||||||
|
|
||||||
### `on_logged_out`
|
### `on_logged_out`
|
||||||
|
@ -162,10 +162,38 @@ return `None`.
|
||||||
If multiple modules implement this callback, they will be considered in order. If a
|
If multiple modules implement this callback, they will be considered in order. If a
|
||||||
callback returns `None`, Synapse falls through to the next one. The value of the first
|
callback returns `None`, Synapse falls through to the next one. The value of the first
|
||||||
callback that does not return `None` will be used. If this happens, Synapse will not call
|
callback that does not return `None` will be used. If this happens, Synapse will not call
|
||||||
any of the subsequent implementations of this callback. If every callback return `None`,
|
any of the subsequent implementations of this callback. If every callback returns `None`,
|
||||||
the username provided by the user is used, if any (otherwise one is automatically
|
the username provided by the user is used, if any (otherwise one is automatically
|
||||||
generated).
|
generated).
|
||||||
|
|
||||||
|
### `get_displayname_for_registration`
|
||||||
|
|
||||||
|
_First introduced in Synapse v1.54.0_
|
||||||
|
|
||||||
|
```python
|
||||||
|
async def get_displayname_for_registration(
|
||||||
|
uia_results: Dict[str, Any],
|
||||||
|
params: Dict[str, Any],
|
||||||
|
) -> Optional[str]
|
||||||
|
```
|
||||||
|
|
||||||
|
Called when registering a new user. The module can return a display name to set for the
|
||||||
|
user being registered by returning it as a string, or `None` if it doesn't wish to force a
|
||||||
|
display name for this user.
|
||||||
|
|
||||||
|
This callback is called once [User-Interactive Authentication](https://spec.matrix.org/latest/client-server-api/#user-interactive-authentication-api)
|
||||||
|
has been completed by the user. It is not called when registering a user via SSO. It is
|
||||||
|
passed two dictionaries, which include the information that the user has provided during
|
||||||
|
the registration process. These dictionaries are identical to the ones passed to
|
||||||
|
[`get_username_for_registration`](#get_username_for_registration), so refer to the
|
||||||
|
documentation of this callback for more information about them.
|
||||||
|
|
||||||
|
If multiple modules implement this callback, they will be considered in order. If a
|
||||||
|
callback returns `None`, Synapse falls through to the next one. The value of the first
|
||||||
|
callback that does not return `None` will be used. If this happens, Synapse will not call
|
||||||
|
any of the subsequent implementations of this callback. If every callback returns `None`,
|
||||||
|
the username will be used (e.g. `alice` if the user being registered is `@alice:example.com`).
|
||||||
|
|
||||||
## `is_3pid_allowed`
|
## `is_3pid_allowed`
|
||||||
|
|
||||||
_First introduced in Synapse v1.53.0_
|
_First introduced in Synapse v1.53.0_
|
||||||
|
@ -196,7 +224,6 @@ The example module below implements authentication checkers for two different lo
|
||||||
- Expects a `password` field to be sent to `/login`
|
- Expects a `password` field to be sent to `/login`
|
||||||
- Is checked by the method: `self.check_pass`
|
- Is checked by the method: `self.check_pass`
|
||||||
|
|
||||||
|
|
||||||
```python
|
```python
|
||||||
from typing import Awaitable, Callable, Optional, Tuple
|
from typing import Awaitable, Callable, Optional, Tuple
|
||||||
|
|
||||||
|
|
|
@ -2064,6 +2064,10 @@ GET_USERNAME_FOR_REGISTRATION_CALLBACK = Callable[
|
||||||
[JsonDict, JsonDict],
|
[JsonDict, JsonDict],
|
||||||
Awaitable[Optional[str]],
|
Awaitable[Optional[str]],
|
||||||
]
|
]
|
||||||
|
GET_DISPLAYNAME_FOR_REGISTRATION_CALLBACK = Callable[
|
||||||
|
[JsonDict, JsonDict],
|
||||||
|
Awaitable[Optional[str]],
|
||||||
|
]
|
||||||
IS_3PID_ALLOWED_CALLBACK = Callable[[str, str, bool], Awaitable[bool]]
|
IS_3PID_ALLOWED_CALLBACK = Callable[[str, str, bool], Awaitable[bool]]
|
||||||
|
|
||||||
|
|
||||||
|
@ -2080,6 +2084,9 @@ class PasswordAuthProvider:
|
||||||
self.get_username_for_registration_callbacks: List[
|
self.get_username_for_registration_callbacks: List[
|
||||||
GET_USERNAME_FOR_REGISTRATION_CALLBACK
|
GET_USERNAME_FOR_REGISTRATION_CALLBACK
|
||||||
] = []
|
] = []
|
||||||
|
self.get_displayname_for_registration_callbacks: List[
|
||||||
|
GET_DISPLAYNAME_FOR_REGISTRATION_CALLBACK
|
||||||
|
] = []
|
||||||
self.is_3pid_allowed_callbacks: List[IS_3PID_ALLOWED_CALLBACK] = []
|
self.is_3pid_allowed_callbacks: List[IS_3PID_ALLOWED_CALLBACK] = []
|
||||||
|
|
||||||
# Mapping from login type to login parameters
|
# Mapping from login type to login parameters
|
||||||
|
@ -2099,6 +2106,9 @@ class PasswordAuthProvider:
|
||||||
get_username_for_registration: Optional[
|
get_username_for_registration: Optional[
|
||||||
GET_USERNAME_FOR_REGISTRATION_CALLBACK
|
GET_USERNAME_FOR_REGISTRATION_CALLBACK
|
||||||
] = None,
|
] = None,
|
||||||
|
get_displayname_for_registration: Optional[
|
||||||
|
GET_DISPLAYNAME_FOR_REGISTRATION_CALLBACK
|
||||||
|
] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
# Register check_3pid_auth callback
|
# Register check_3pid_auth callback
|
||||||
if check_3pid_auth is not None:
|
if check_3pid_auth is not None:
|
||||||
|
@ -2148,6 +2158,11 @@ class PasswordAuthProvider:
|
||||||
get_username_for_registration,
|
get_username_for_registration,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if get_displayname_for_registration is not None:
|
||||||
|
self.get_displayname_for_registration_callbacks.append(
|
||||||
|
get_displayname_for_registration,
|
||||||
|
)
|
||||||
|
|
||||||
if is_3pid_allowed is not None:
|
if is_3pid_allowed is not None:
|
||||||
self.is_3pid_allowed_callbacks.append(is_3pid_allowed)
|
self.is_3pid_allowed_callbacks.append(is_3pid_allowed)
|
||||||
|
|
||||||
|
@ -2350,6 +2365,49 @@ class PasswordAuthProvider:
|
||||||
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
async def get_displayname_for_registration(
|
||||||
|
self,
|
||||||
|
uia_results: JsonDict,
|
||||||
|
params: JsonDict,
|
||||||
|
) -> Optional[str]:
|
||||||
|
"""Defines the display name to use when registering the user, using the
|
||||||
|
credentials and parameters provided during the UIA flow.
|
||||||
|
|
||||||
|
Stops at the first callback that returns a tuple containing at least one string.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
uia_results: The credentials provided during the UIA flow.
|
||||||
|
params: The parameters provided by the registration request.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A tuple which first element is the display name, and the second is an MXC URL
|
||||||
|
to the user's avatar.
|
||||||
|
"""
|
||||||
|
for callback in self.get_displayname_for_registration_callbacks:
|
||||||
|
try:
|
||||||
|
res = await callback(uia_results, params)
|
||||||
|
|
||||||
|
if isinstance(res, str):
|
||||||
|
return res
|
||||||
|
elif res is not None:
|
||||||
|
# mypy complains that this line is unreachable because it assumes the
|
||||||
|
# data returned by the module fits the expected type. We just want
|
||||||
|
# to make sure this is the case.
|
||||||
|
logger.warning( # type: ignore[unreachable]
|
||||||
|
"Ignoring non-string value returned by"
|
||||||
|
" get_displayname_for_registration callback %s: %s",
|
||||||
|
callback,
|
||||||
|
res,
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(
|
||||||
|
"Module raised an exception in get_displayname_for_registration: %s",
|
||||||
|
e,
|
||||||
|
)
|
||||||
|
raise SynapseError(code=500, msg="Internal Server Error")
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
async def is_3pid_allowed(
|
async def is_3pid_allowed(
|
||||||
self,
|
self,
|
||||||
medium: str,
|
medium: str,
|
||||||
|
|
|
@ -70,6 +70,7 @@ from synapse.handlers.account_validity import (
|
||||||
from synapse.handlers.auth import (
|
from synapse.handlers.auth import (
|
||||||
CHECK_3PID_AUTH_CALLBACK,
|
CHECK_3PID_AUTH_CALLBACK,
|
||||||
CHECK_AUTH_CALLBACK,
|
CHECK_AUTH_CALLBACK,
|
||||||
|
GET_DISPLAYNAME_FOR_REGISTRATION_CALLBACK,
|
||||||
GET_USERNAME_FOR_REGISTRATION_CALLBACK,
|
GET_USERNAME_FOR_REGISTRATION_CALLBACK,
|
||||||
IS_3PID_ALLOWED_CALLBACK,
|
IS_3PID_ALLOWED_CALLBACK,
|
||||||
ON_LOGGED_OUT_CALLBACK,
|
ON_LOGGED_OUT_CALLBACK,
|
||||||
|
@ -317,6 +318,9 @@ class ModuleApi:
|
||||||
get_username_for_registration: Optional[
|
get_username_for_registration: Optional[
|
||||||
GET_USERNAME_FOR_REGISTRATION_CALLBACK
|
GET_USERNAME_FOR_REGISTRATION_CALLBACK
|
||||||
] = None,
|
] = None,
|
||||||
|
get_displayname_for_registration: Optional[
|
||||||
|
GET_DISPLAYNAME_FOR_REGISTRATION_CALLBACK
|
||||||
|
] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Registers callbacks for password auth provider capabilities.
|
"""Registers callbacks for password auth provider capabilities.
|
||||||
|
|
||||||
|
@ -328,6 +332,7 @@ class ModuleApi:
|
||||||
is_3pid_allowed=is_3pid_allowed,
|
is_3pid_allowed=is_3pid_allowed,
|
||||||
auth_checkers=auth_checkers,
|
auth_checkers=auth_checkers,
|
||||||
get_username_for_registration=get_username_for_registration,
|
get_username_for_registration=get_username_for_registration,
|
||||||
|
get_displayname_for_registration=get_displayname_for_registration,
|
||||||
)
|
)
|
||||||
|
|
||||||
def register_background_update_controller_callbacks(
|
def register_background_update_controller_callbacks(
|
||||||
|
|
|
@ -694,11 +694,18 @@ class RegisterRestServlet(RestServlet):
|
||||||
session_id
|
session_id
|
||||||
)
|
)
|
||||||
|
|
||||||
|
display_name = await (
|
||||||
|
self.password_auth_provider.get_displayname_for_registration(
|
||||||
|
auth_result, params
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
registered_user_id = await self.registration_handler.register_user(
|
registered_user_id = await self.registration_handler.register_user(
|
||||||
localpart=desired_username,
|
localpart=desired_username,
|
||||||
password_hash=password_hash,
|
password_hash=password_hash,
|
||||||
guest_access_token=guest_access_token,
|
guest_access_token=guest_access_token,
|
||||||
threepid=threepid,
|
threepid=threepid,
|
||||||
|
default_display_name=display_name,
|
||||||
address=client_addr,
|
address=client_addr,
|
||||||
user_agent_ips=entries,
|
user_agent_ips=entries,
|
||||||
)
|
)
|
||||||
|
|
|
@ -84,7 +84,7 @@ class CustomAuthProvider:
|
||||||
|
|
||||||
def __init__(self, config, api: ModuleApi):
|
def __init__(self, config, api: ModuleApi):
|
||||||
api.register_password_auth_provider_callbacks(
|
api.register_password_auth_provider_callbacks(
|
||||||
auth_checkers={("test.login_type", ("test_field",)): self.check_auth},
|
auth_checkers={("test.login_type", ("test_field",)): self.check_auth}
|
||||||
)
|
)
|
||||||
|
|
||||||
def check_auth(self, *args):
|
def check_auth(self, *args):
|
||||||
|
@ -122,7 +122,7 @@ class PasswordCustomAuthProvider:
|
||||||
auth_checkers={
|
auth_checkers={
|
||||||
("test.login_type", ("test_field",)): self.check_auth,
|
("test.login_type", ("test_field",)): self.check_auth,
|
||||||
("m.login.password", ("password",)): self.check_auth,
|
("m.login.password", ("password",)): self.check_auth,
|
||||||
},
|
}
|
||||||
)
|
)
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@ -163,6 +163,9 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
|
||||||
account.register_servlets,
|
account.register_servlets,
|
||||||
]
|
]
|
||||||
|
|
||||||
|
CALLBACK_USERNAME = "get_username_for_registration"
|
||||||
|
CALLBACK_DISPLAYNAME = "get_displayname_for_registration"
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
# we use a global mock device, so make sure we are starting with a clean slate
|
# we use a global mock device, so make sure we are starting with a clean slate
|
||||||
mock_password_provider.reset_mock()
|
mock_password_provider.reset_mock()
|
||||||
|
@ -754,7 +757,9 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
|
||||||
"""Tests that the get_username_for_registration callback can define the username
|
"""Tests that the get_username_for_registration callback can define the username
|
||||||
of a user when registering.
|
of a user when registering.
|
||||||
"""
|
"""
|
||||||
self._setup_get_username_for_registration()
|
self._setup_get_name_for_registration(
|
||||||
|
callback_name=self.CALLBACK_USERNAME,
|
||||||
|
)
|
||||||
|
|
||||||
username = "rin"
|
username = "rin"
|
||||||
channel = self.make_request(
|
channel = self.make_request(
|
||||||
|
@ -777,30 +782,14 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
|
||||||
"""Tests that the get_username_for_registration callback is only called at the
|
"""Tests that the get_username_for_registration callback is only called at the
|
||||||
end of the UIA flow.
|
end of the UIA flow.
|
||||||
"""
|
"""
|
||||||
m = self._setup_get_username_for_registration()
|
m = self._setup_get_name_for_registration(
|
||||||
|
callback_name=self.CALLBACK_USERNAME,
|
||||||
|
)
|
||||||
|
|
||||||
# Initiate the UIA flow.
|
|
||||||
username = "rin"
|
username = "rin"
|
||||||
channel = self.make_request(
|
res = self._do_uia_assert_mock_not_called(username, m)
|
||||||
"POST",
|
|
||||||
"register",
|
|
||||||
{"username": username, "type": "m.login.password", "password": "bar"},
|
|
||||||
)
|
|
||||||
self.assertEqual(channel.code, 401)
|
|
||||||
self.assertIn("session", channel.json_body)
|
|
||||||
|
|
||||||
# Check that the callback hasn't been called yet.
|
mxid = res["user_id"]
|
||||||
m.assert_not_called()
|
|
||||||
|
|
||||||
# Finish the UIA flow.
|
|
||||||
session = channel.json_body["session"]
|
|
||||||
channel = self.make_request(
|
|
||||||
"POST",
|
|
||||||
"register",
|
|
||||||
{"auth": {"session": session, "type": LoginType.DUMMY}},
|
|
||||||
)
|
|
||||||
self.assertEqual(channel.code, 200, channel.json_body)
|
|
||||||
mxid = channel.json_body["user_id"]
|
|
||||||
self.assertEqual(UserID.from_string(mxid).localpart, username + "-foo")
|
self.assertEqual(UserID.from_string(mxid).localpart, username + "-foo")
|
||||||
|
|
||||||
# Check that the callback has been called.
|
# Check that the callback has been called.
|
||||||
|
@ -817,6 +806,56 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
|
||||||
self._test_3pid_allowed("rin", False)
|
self._test_3pid_allowed("rin", False)
|
||||||
self._test_3pid_allowed("kitay", True)
|
self._test_3pid_allowed("kitay", True)
|
||||||
|
|
||||||
|
def test_displayname(self):
|
||||||
|
"""Tests that the get_displayname_for_registration callback can define the
|
||||||
|
display name of a user when registering.
|
||||||
|
"""
|
||||||
|
self._setup_get_name_for_registration(
|
||||||
|
callback_name=self.CALLBACK_DISPLAYNAME,
|
||||||
|
)
|
||||||
|
|
||||||
|
username = "rin"
|
||||||
|
channel = self.make_request(
|
||||||
|
"POST",
|
||||||
|
"/register",
|
||||||
|
{
|
||||||
|
"username": username,
|
||||||
|
"password": "bar",
|
||||||
|
"auth": {"type": LoginType.DUMMY},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
self.assertEqual(channel.code, 200)
|
||||||
|
|
||||||
|
# Our callback takes the username and appends "-foo" to it, check that's what we
|
||||||
|
# have.
|
||||||
|
user_id = UserID.from_string(channel.json_body["user_id"])
|
||||||
|
display_name = self.get_success(
|
||||||
|
self.hs.get_profile_handler().get_displayname(user_id)
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assertEqual(display_name, username + "-foo")
|
||||||
|
|
||||||
|
def test_displayname_uia(self):
|
||||||
|
"""Tests that the get_displayname_for_registration callback is only called at the
|
||||||
|
end of the UIA flow.
|
||||||
|
"""
|
||||||
|
m = self._setup_get_name_for_registration(
|
||||||
|
callback_name=self.CALLBACK_DISPLAYNAME,
|
||||||
|
)
|
||||||
|
|
||||||
|
username = "rin"
|
||||||
|
res = self._do_uia_assert_mock_not_called(username, m)
|
||||||
|
|
||||||
|
user_id = UserID.from_string(res["user_id"])
|
||||||
|
display_name = self.get_success(
|
||||||
|
self.hs.get_profile_handler().get_displayname(user_id)
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assertEqual(display_name, username + "-foo")
|
||||||
|
|
||||||
|
# Check that the callback has been called.
|
||||||
|
m.assert_called_once()
|
||||||
|
|
||||||
def _test_3pid_allowed(self, username: str, registration: bool):
|
def _test_3pid_allowed(self, username: str, registration: bool):
|
||||||
"""Tests that the "is_3pid_allowed" module callback is called correctly, using
|
"""Tests that the "is_3pid_allowed" module callback is called correctly, using
|
||||||
either /register or /account URLs depending on the arguments.
|
either /register or /account URLs depending on the arguments.
|
||||||
|
@ -877,23 +916,47 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
|
||||||
|
|
||||||
m.assert_called_once_with("email", "bar@test.com", registration)
|
m.assert_called_once_with("email", "bar@test.com", registration)
|
||||||
|
|
||||||
def _setup_get_username_for_registration(self) -> Mock:
|
def _setup_get_name_for_registration(self, callback_name: str) -> Mock:
|
||||||
"""Registers a get_username_for_registration callback that appends "-foo" to the
|
"""Registers either a get_username_for_registration callback or a
|
||||||
username the client is trying to register.
|
get_displayname_for_registration callback that appends "-foo" to the username the
|
||||||
|
client is trying to register.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
async def get_username_for_registration(uia_results, params):
|
async def callback(uia_results, params):
|
||||||
self.assertIn(LoginType.DUMMY, uia_results)
|
self.assertIn(LoginType.DUMMY, uia_results)
|
||||||
username = params["username"]
|
username = params["username"]
|
||||||
return username + "-foo"
|
return username + "-foo"
|
||||||
|
|
||||||
m = Mock(side_effect=get_username_for_registration)
|
m = Mock(side_effect=callback)
|
||||||
|
|
||||||
password_auth_provider = self.hs.get_password_auth_provider()
|
password_auth_provider = self.hs.get_password_auth_provider()
|
||||||
password_auth_provider.get_username_for_registration_callbacks.append(m)
|
getattr(password_auth_provider, callback_name + "_callbacks").append(m)
|
||||||
|
|
||||||
return m
|
return m
|
||||||
|
|
||||||
|
def _do_uia_assert_mock_not_called(self, username: str, m: Mock) -> JsonDict:
|
||||||
|
# Initiate the UIA flow.
|
||||||
|
channel = self.make_request(
|
||||||
|
"POST",
|
||||||
|
"register",
|
||||||
|
{"username": username, "type": "m.login.password", "password": "bar"},
|
||||||
|
)
|
||||||
|
self.assertEqual(channel.code, 401)
|
||||||
|
self.assertIn("session", channel.json_body)
|
||||||
|
|
||||||
|
# Check that the callback hasn't been called yet.
|
||||||
|
m.assert_not_called()
|
||||||
|
|
||||||
|
# Finish the UIA flow.
|
||||||
|
session = channel.json_body["session"]
|
||||||
|
channel = self.make_request(
|
||||||
|
"POST",
|
||||||
|
"register",
|
||||||
|
{"auth": {"session": session, "type": LoginType.DUMMY}},
|
||||||
|
)
|
||||||
|
self.assertEqual(channel.code, 200, channel.json_body)
|
||||||
|
return channel.json_body
|
||||||
|
|
||||||
def _get_login_flows(self) -> JsonDict:
|
def _get_login_flows(self) -> JsonDict:
|
||||||
channel = self.make_request("GET", "/_matrix/client/r0/login")
|
channel = self.make_request("GET", "/_matrix/client/r0/login")
|
||||||
self.assertEqual(channel.code, 200, channel.result)
|
self.assertEqual(channel.code, 200, channel.result)
|
||||||
|
|
Loading…
Reference in New Issue