Add a callback to allow modules to deny 3PID (#11854)

Part of the Tchap Synapse mainlining.

This allows modules to implement extra logic to figure out whether a given 3PID can be added to the local homeserver. In the Tchap use case, this will allow a Synapse module to interface with the custom endpoint /internal_info.
pull/11939/head
Brendan Abolivier 2022-02-08 11:20:32 +01:00 committed by GitHub
parent fef2e792be
commit 0640f8ebaa
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 161 additions and 7 deletions

View File

@ -0,0 +1 @@
Add a callback to allow modules to allow or forbid a 3PID (email address, phone number) from being associated to a local account.

View File

@ -166,6 +166,25 @@ any of the subsequent implementations of this callback. If every callback return
the username provided by the user is used, if any (otherwise one is automatically the username provided by the user is used, if any (otherwise one is automatically
generated). generated).
## `is_3pid_allowed`
_First introduced in Synapse v1.53.0_
```python
async def is_3pid_allowed(self, medium: str, address: str, registration: bool) -> bool
```
Called when attempting to bind a third-party identifier (i.e. an email address or a phone
number). The module is given the medium of the third-party identifier (which is `email` if
the identifier is an email address, or `msisdn` if the identifier is a phone number) and
its address, as well as a boolean indicating whether the attempt to bind is happening as
part of registering a new user. The module must return a boolean indicating whether the
identifier can be allowed to be bound to an account on the local homeserver.
If multiple modules implement this callback, they will be considered in order. If a
callback returns `True`, Synapse falls through to the next one. The value of the first
callback that does not return `True` will be used. If this happens, Synapse will not call
any of the subsequent implementations of this callback.
## Example ## Example

View File

@ -2064,6 +2064,7 @@ GET_USERNAME_FOR_REGISTRATION_CALLBACK = Callable[
[JsonDict, JsonDict], [JsonDict, JsonDict],
Awaitable[Optional[str]], Awaitable[Optional[str]],
] ]
IS_3PID_ALLOWED_CALLBACK = Callable[[str, str, bool], Awaitable[bool]]
class PasswordAuthProvider: class PasswordAuthProvider:
@ -2079,6 +2080,7 @@ class PasswordAuthProvider:
self.get_username_for_registration_callbacks: List[ self.get_username_for_registration_callbacks: List[
GET_USERNAME_FOR_REGISTRATION_CALLBACK GET_USERNAME_FOR_REGISTRATION_CALLBACK
] = [] ] = []
self.is_3pid_allowed_callbacks: List[IS_3PID_ALLOWED_CALLBACK] = []
# Mapping from login type to login parameters # Mapping from login type to login parameters
self._supported_login_types: Dict[str, Iterable[str]] = {} self._supported_login_types: Dict[str, Iterable[str]] = {}
@ -2090,6 +2092,7 @@ class PasswordAuthProvider:
self, self,
check_3pid_auth: Optional[CHECK_3PID_AUTH_CALLBACK] = None, check_3pid_auth: Optional[CHECK_3PID_AUTH_CALLBACK] = None,
on_logged_out: Optional[ON_LOGGED_OUT_CALLBACK] = None, on_logged_out: Optional[ON_LOGGED_OUT_CALLBACK] = None,
is_3pid_allowed: Optional[IS_3PID_ALLOWED_CALLBACK] = None,
auth_checkers: Optional[ auth_checkers: Optional[
Dict[Tuple[str, Tuple[str, ...]], CHECK_AUTH_CALLBACK] Dict[Tuple[str, Tuple[str, ...]], CHECK_AUTH_CALLBACK]
] = None, ] = None,
@ -2145,6 +2148,9 @@ class PasswordAuthProvider:
get_username_for_registration, get_username_for_registration,
) )
if is_3pid_allowed is not None:
self.is_3pid_allowed_callbacks.append(is_3pid_allowed)
def get_supported_login_types(self) -> Mapping[str, Iterable[str]]: def get_supported_login_types(self) -> Mapping[str, Iterable[str]]:
"""Get the login types supported by this password provider """Get the login types supported by this password provider
@ -2343,3 +2349,41 @@ class PasswordAuthProvider:
raise SynapseError(code=500, msg="Internal Server Error") raise SynapseError(code=500, msg="Internal Server Error")
return None return None
async def is_3pid_allowed(
self,
medium: str,
address: str,
registration: bool,
) -> bool:
"""Check if the user can be allowed to bind a 3PID on this homeserver.
Args:
medium: The medium of the 3PID.
address: The address of the 3PID.
registration: Whether the 3PID is being bound when registering a new user.
Returns:
Whether the 3PID is allowed to be bound on this homeserver
"""
for callback in self.is_3pid_allowed_callbacks:
try:
res = await callback(medium, address, registration)
if res is False:
return res
elif not isinstance(res, bool):
# mypy complains that this line is unreachable because it assumes the
# data returned by the module fits the expected type. We just want
# to make sure this is the case.
logger.warning( # type: ignore[unreachable]
"Ignoring non-string value returned by"
" is_3pid_allowed callback %s: %s",
callback,
res,
)
except Exception as e:
logger.error("Module raised an exception in is_3pid_allowed: %s", e)
raise SynapseError(code=500, msg="Internal Server Error")
return True

View File

@ -72,6 +72,7 @@ from synapse.handlers.auth import (
CHECK_3PID_AUTH_CALLBACK, CHECK_3PID_AUTH_CALLBACK,
CHECK_AUTH_CALLBACK, CHECK_AUTH_CALLBACK,
GET_USERNAME_FOR_REGISTRATION_CALLBACK, GET_USERNAME_FOR_REGISTRATION_CALLBACK,
IS_3PID_ALLOWED_CALLBACK,
ON_LOGGED_OUT_CALLBACK, ON_LOGGED_OUT_CALLBACK,
AuthHandler, AuthHandler,
) )
@ -312,6 +313,7 @@ class ModuleApi:
auth_checkers: Optional[ auth_checkers: Optional[
Dict[Tuple[str, Tuple[str, ...]], CHECK_AUTH_CALLBACK] Dict[Tuple[str, Tuple[str, ...]], CHECK_AUTH_CALLBACK]
] = None, ] = None,
is_3pid_allowed: Optional[IS_3PID_ALLOWED_CALLBACK] = None,
get_username_for_registration: Optional[ get_username_for_registration: Optional[
GET_USERNAME_FOR_REGISTRATION_CALLBACK GET_USERNAME_FOR_REGISTRATION_CALLBACK
] = None, ] = None,
@ -323,6 +325,7 @@ class ModuleApi:
return self._password_auth_provider.register_password_auth_provider_callbacks( return self._password_auth_provider.register_password_auth_provider_callbacks(
check_3pid_auth=check_3pid_auth, check_3pid_auth=check_3pid_auth,
on_logged_out=on_logged_out, on_logged_out=on_logged_out,
is_3pid_allowed=is_3pid_allowed,
auth_checkers=auth_checkers, auth_checkers=auth_checkers,
get_username_for_registration=get_username_for_registration, get_username_for_registration=get_username_for_registration,
) )

View File

@ -385,7 +385,7 @@ class EmailThreepidRequestTokenRestServlet(RestServlet):
send_attempt = body["send_attempt"] send_attempt = body["send_attempt"]
next_link = body.get("next_link") # Optional param next_link = body.get("next_link") # Optional param
if not check_3pid_allowed(self.hs, "email", email): if not await check_3pid_allowed(self.hs, "email", email):
raise SynapseError( raise SynapseError(
403, 403,
"Your email domain is not authorized on this server", "Your email domain is not authorized on this server",
@ -468,7 +468,7 @@ class MsisdnThreepidRequestTokenRestServlet(RestServlet):
msisdn = phone_number_to_msisdn(country, phone_number) msisdn = phone_number_to_msisdn(country, phone_number)
if not check_3pid_allowed(self.hs, "msisdn", msisdn): if not await check_3pid_allowed(self.hs, "msisdn", msisdn):
raise SynapseError( raise SynapseError(
403, 403,
"Account phone numbers are not authorized on this server", "Account phone numbers are not authorized on this server",

View File

@ -112,7 +112,7 @@ class EmailRegisterRequestTokenRestServlet(RestServlet):
send_attempt = body["send_attempt"] send_attempt = body["send_attempt"]
next_link = body.get("next_link") # Optional param next_link = body.get("next_link") # Optional param
if not check_3pid_allowed(self.hs, "email", email): if not await check_3pid_allowed(self.hs, "email", email, registration=True):
raise SynapseError( raise SynapseError(
403, 403,
"Your email domain is not authorized to register on this server", "Your email domain is not authorized to register on this server",
@ -192,7 +192,7 @@ class MsisdnRegisterRequestTokenRestServlet(RestServlet):
msisdn = phone_number_to_msisdn(country, phone_number) msisdn = phone_number_to_msisdn(country, phone_number)
if not check_3pid_allowed(self.hs, "msisdn", msisdn): if not await check_3pid_allowed(self.hs, "msisdn", msisdn, registration=True):
raise SynapseError( raise SynapseError(
403, 403,
"Phone numbers are not authorized to register on this server", "Phone numbers are not authorized to register on this server",
@ -616,7 +616,9 @@ class RegisterRestServlet(RestServlet):
medium = auth_result[login_type]["medium"] medium = auth_result[login_type]["medium"]
address = auth_result[login_type]["address"] address = auth_result[login_type]["address"]
if not check_3pid_allowed(self.hs, medium, address): if not await check_3pid_allowed(
self.hs, medium, address, registration=True
):
raise SynapseError( raise SynapseError(
403, 403,
"Third party identifiers (email/phone numbers)" "Third party identifiers (email/phone numbers)"

View File

@ -32,7 +32,12 @@ logger = logging.getLogger(__name__)
MAX_EMAIL_ADDRESS_LENGTH = 500 MAX_EMAIL_ADDRESS_LENGTH = 500
def check_3pid_allowed(hs: "HomeServer", medium: str, address: str) -> bool: async def check_3pid_allowed(
hs: "HomeServer",
medium: str,
address: str,
registration: bool = False,
) -> bool:
"""Checks whether a given format of 3PID is allowed to be used on this HS """Checks whether a given format of 3PID is allowed to be used on this HS
Args: Args:
@ -40,9 +45,15 @@ def check_3pid_allowed(hs: "HomeServer", medium: str, address: str) -> bool:
medium: 3pid medium - e.g. email, msisdn medium: 3pid medium - e.g. email, msisdn
address: address within that medium (e.g. "wotan@matrix.org") address: address within that medium (e.g. "wotan@matrix.org")
msisdns need to first have been canonicalised msisdns need to first have been canonicalised
registration: whether we want to bind the 3PID as part of registering a new user.
Returns: Returns:
bool: whether the 3PID medium/address is allowed to be added to this HS bool: whether the 3PID medium/address is allowed to be added to this HS
""" """
if not await hs.get_password_auth_provider().is_3pid_allowed(
medium, address, registration
):
return False
if hs.config.registration.allowed_local_3pids: if hs.config.registration.allowed_local_3pids:
for constraint in hs.config.registration.allowed_local_3pids: for constraint in hs.config.registration.allowed_local_3pids:

View File

@ -21,13 +21,15 @@ from twisted.internet import defer
import synapse import synapse
from synapse.api.constants import LoginType from synapse.api.constants import LoginType
from synapse.api.errors import Codes
from synapse.handlers.auth import load_legacy_password_auth_providers from synapse.handlers.auth import load_legacy_password_auth_providers
from synapse.module_api import ModuleApi from synapse.module_api import ModuleApi
from synapse.rest.client import devices, login, logout, register from synapse.rest.client import account, devices, login, logout, register
from synapse.types import JsonDict, UserID from synapse.types import JsonDict, UserID
from tests import unittest from tests import unittest
from tests.server import FakeChannel from tests.server import FakeChannel
from tests.test_utils import make_awaitable
from tests.unittest import override_config from tests.unittest import override_config
# (possibly experimental) login flows we expect to appear in the list after the normal # (possibly experimental) login flows we expect to appear in the list after the normal
@ -158,6 +160,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
devices.register_servlets, devices.register_servlets,
logout.register_servlets, logout.register_servlets,
register.register_servlets, register.register_servlets,
account.register_servlets,
] ]
def setUp(self): def setUp(self):
@ -803,6 +806,77 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
# Check that the callback has been called. # Check that the callback has been called.
m.assert_called_once() m.assert_called_once()
# Set some email configuration so the test doesn't fail because of its absence.
@override_config({"email": {"notif_from": "noreply@test"}})
def test_3pid_allowed(self):
"""Tests that an is_3pid_allowed_callbacks forbidding a 3PID makes Synapse refuse
to bind the new 3PID, and that one allowing a 3PID makes Synapse accept to bind
the 3PID. Also checks that the module is passed a boolean indicating whether the
user to bind this 3PID to is currently registering.
"""
self._test_3pid_allowed("rin", False)
self._test_3pid_allowed("kitay", True)
def _test_3pid_allowed(self, username: str, registration: bool):
"""Tests that the "is_3pid_allowed" module callback is called correctly, using
either /register or /account URLs depending on the arguments.
Args:
username: The username to use for the test.
registration: Whether to test with registration URLs.
"""
self.hs.get_identity_handler().send_threepid_validation = Mock(
return_value=make_awaitable(0),
)
m = Mock(return_value=make_awaitable(False))
self.hs.get_password_auth_provider().is_3pid_allowed_callbacks = [m]
self.register_user(username, "password")
tok = self.login(username, "password")
if registration:
url = "/register/email/requestToken"
else:
url = "/account/3pid/email/requestToken"
channel = self.make_request(
"POST",
url,
{
"client_secret": "foo",
"email": "foo@test.com",
"send_attempt": 0,
},
access_token=tok,
)
self.assertEqual(channel.code, 403, channel.result)
self.assertEqual(
channel.json_body["errcode"],
Codes.THREEPID_DENIED,
channel.json_body,
)
m.assert_called_once_with("email", "foo@test.com", registration)
m = Mock(return_value=make_awaitable(True))
self.hs.get_password_auth_provider().is_3pid_allowed_callbacks = [m]
channel = self.make_request(
"POST",
url,
{
"client_secret": "foo",
"email": "bar@test.com",
"send_attempt": 0,
},
access_token=tok,
)
self.assertEqual(channel.code, 200, channel.result)
self.assertIn("sid", channel.json_body)
m.assert_called_once_with("email", "bar@test.com", registration)
def _setup_get_username_for_registration(self) -> Mock: def _setup_get_username_for_registration(self) -> Mock:
"""Registers a get_username_for_registration callback that appends "-foo" to the """Registers a get_username_for_registration callback that appends "-foo" to the
username the client is trying to register. username the client is trying to register.