Port the Password Auth Providers module interface to the new generic interface (#10548)
Co-authored-by: Azrenbeth <7782548+Azrenbeth@users.noreply.github.com> Co-authored-by: Brendan Abolivier <babolivier@matrix.org>pull/11069/head
parent
732bbf6737
commit
cdd308845b
|
@ -0,0 +1 @@
|
||||||
|
Port the Password Auth Providers module interface to the new generic interface.
|
|
@ -43,6 +43,7 @@
|
||||||
- [Third-party rules callbacks](modules/third_party_rules_callbacks.md)
|
- [Third-party rules callbacks](modules/third_party_rules_callbacks.md)
|
||||||
- [Presence router callbacks](modules/presence_router_callbacks.md)
|
- [Presence router callbacks](modules/presence_router_callbacks.md)
|
||||||
- [Account validity callbacks](modules/account_validity_callbacks.md)
|
- [Account validity callbacks](modules/account_validity_callbacks.md)
|
||||||
|
- [Password auth provider callbacks](modules/password_auth_provider_callbacks.md)
|
||||||
- [Porting a legacy module to the new interface](modules/porting_legacy_module.md)
|
- [Porting a legacy module to the new interface](modules/porting_legacy_module.md)
|
||||||
- [Workers](workers.md)
|
- [Workers](workers.md)
|
||||||
- [Using `synctl` with Workers](synctl_workers.md)
|
- [Using `synctl` with Workers](synctl_workers.md)
|
||||||
|
|
|
@ -0,0 +1,153 @@
|
||||||
|
# Password auth provider callbacks
|
||||||
|
|
||||||
|
Password auth providers offer a way for server administrators to integrate
|
||||||
|
their Synapse installation with an external authentication system. The callbacks can be
|
||||||
|
registered by using the Module API's `register_password_auth_provider_callbacks` method.
|
||||||
|
|
||||||
|
## Callbacks
|
||||||
|
|
||||||
|
### `auth_checkers`
|
||||||
|
|
||||||
|
```
|
||||||
|
auth_checkers: Dict[Tuple[str,Tuple], Callable]
|
||||||
|
```
|
||||||
|
|
||||||
|
A dict mapping from tuples of a login type identifier (such as `m.login.password`) and a
|
||||||
|
tuple of field names (such as `("password", "secret_thing")`) to authentication checking
|
||||||
|
callbacks, which should be of the following form:
|
||||||
|
|
||||||
|
```python
|
||||||
|
async def check_auth(
|
||||||
|
user: str,
|
||||||
|
login_type: str,
|
||||||
|
login_dict: "synapse.module_api.JsonDict",
|
||||||
|
) -> Optional[
|
||||||
|
Tuple[
|
||||||
|
str,
|
||||||
|
Optional[Callable[["synapse.module_api.LoginResponse"], Awaitable[None]]]
|
||||||
|
]
|
||||||
|
]
|
||||||
|
```
|
||||||
|
|
||||||
|
The login type and field names should be provided by the user in the
|
||||||
|
request to the `/login` API. [The Matrix specification](https://matrix.org/docs/spec/client_server/latest#authentication-types)
|
||||||
|
defines some types, however user defined ones are also allowed.
|
||||||
|
|
||||||
|
The callback is passed the `user` field provided by the client (which might not be in
|
||||||
|
`@username:server` form), the login type, and a dictionary of login secrets passed by
|
||||||
|
the client.
|
||||||
|
|
||||||
|
If the authentication is successful, the module must return the user's Matrix ID (e.g.
|
||||||
|
`@alice:example.com`) and optionally a callback to be called with the response to the
|
||||||
|
`/login` request. If the module doesn't wish to return a callback, it must return `None`
|
||||||
|
instead.
|
||||||
|
|
||||||
|
If the authentication is unsuccessful, the module must return `None`.
|
||||||
|
|
||||||
|
### `check_3pid_auth`
|
||||||
|
|
||||||
|
```python
|
||||||
|
async def check_3pid_auth(
|
||||||
|
medium: str,
|
||||||
|
address: str,
|
||||||
|
password: str,
|
||||||
|
) -> Optional[
|
||||||
|
Tuple[
|
||||||
|
str,
|
||||||
|
Optional[Callable[["synapse.module_api.LoginResponse"], Awaitable[None]]]
|
||||||
|
]
|
||||||
|
]
|
||||||
|
```
|
||||||
|
|
||||||
|
Called when a user attempts to register or log in with a third party identifier,
|
||||||
|
such as email. It is passed the medium (eg. `email`), an address (eg. `jdoe@example.com`)
|
||||||
|
and the user's password.
|
||||||
|
|
||||||
|
If the authentication is successful, the module must return the user's Matrix ID (e.g.
|
||||||
|
`@alice:example.com`) and optionally a callback to be called with the response to the `/login` request.
|
||||||
|
If the module doesn't wish to return a callback, it must return None instead.
|
||||||
|
|
||||||
|
If the authentication is unsuccessful, the module must return None.
|
||||||
|
|
||||||
|
### `on_logged_out`
|
||||||
|
|
||||||
|
```python
|
||||||
|
async def on_logged_out(
|
||||||
|
user_id: str,
|
||||||
|
device_id: Optional[str],
|
||||||
|
access_token: str
|
||||||
|
) -> None
|
||||||
|
```
|
||||||
|
Called during a logout request for a user. It is passed the qualified user ID, the ID of the
|
||||||
|
deactivated device (if any: access tokens are occasionally created without an associated
|
||||||
|
device ID), and the (now deactivated) access token.
|
||||||
|
|
||||||
|
## Example
|
||||||
|
|
||||||
|
The example module below implements authentication checkers for two different login types:
|
||||||
|
- `my.login.type`
|
||||||
|
- Expects a `my_field` field to be sent to `/login`
|
||||||
|
- Is checked by the method: `self.check_my_login`
|
||||||
|
- `m.login.password` (defined in [the spec](https://matrix.org/docs/spec/client_server/latest#password-based))
|
||||||
|
- Expects a `password` field to be sent to `/login`
|
||||||
|
- Is checked by the method: `self.check_pass`
|
||||||
|
|
||||||
|
|
||||||
|
```python
|
||||||
|
from typing import Awaitable, Callable, Optional, Tuple
|
||||||
|
|
||||||
|
import synapse
|
||||||
|
from synapse import module_api
|
||||||
|
|
||||||
|
|
||||||
|
class MyAuthProvider:
|
||||||
|
def __init__(self, config: dict, api: module_api):
|
||||||
|
|
||||||
|
self.api = api
|
||||||
|
|
||||||
|
self.credentials = {
|
||||||
|
"bob": "building",
|
||||||
|
"@scoop:matrix.org": "digging",
|
||||||
|
}
|
||||||
|
|
||||||
|
api.register_password_auth_provider_callbacks(
|
||||||
|
auth_checkers={
|
||||||
|
("my.login_type", ("my_field",)): self.check_my_login,
|
||||||
|
("m.login.password", ("password",)): self.check_pass,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
async def check_my_login(
|
||||||
|
self,
|
||||||
|
username: str,
|
||||||
|
login_type: str,
|
||||||
|
login_dict: "synapse.module_api.JsonDict",
|
||||||
|
) -> Optional[
|
||||||
|
Tuple[
|
||||||
|
str,
|
||||||
|
Optional[Callable[["synapse.module_api.LoginResponse"], Awaitable[None]]],
|
||||||
|
]
|
||||||
|
]:
|
||||||
|
if login_type != "my.login_type":
|
||||||
|
return None
|
||||||
|
|
||||||
|
if self.credentials.get(username) == login_dict.get("my_field"):
|
||||||
|
return self.api.get_qualified_user_id(username)
|
||||||
|
|
||||||
|
async def check_pass(
|
||||||
|
self,
|
||||||
|
username: str,
|
||||||
|
login_type: str,
|
||||||
|
login_dict: "synapse.module_api.JsonDict",
|
||||||
|
) -> Optional[
|
||||||
|
Tuple[
|
||||||
|
str,
|
||||||
|
Optional[Callable[["synapse.module_api.LoginResponse"], Awaitable[None]]],
|
||||||
|
]
|
||||||
|
]:
|
||||||
|
if login_type != "m.login.password":
|
||||||
|
return None
|
||||||
|
|
||||||
|
if self.credentials.get(username) == login_dict.get("password"):
|
||||||
|
return self.api.get_qualified_user_id(username)
|
||||||
|
```
|
|
@ -12,6 +12,9 @@ should register this resource in its `__init__` method using the `register_web_r
|
||||||
method from the `ModuleApi` class (see [this section](writing_a_module.html#registering-a-web-resource) for
|
method from the `ModuleApi` class (see [this section](writing_a_module.html#registering-a-web-resource) for
|
||||||
more info).
|
more info).
|
||||||
|
|
||||||
|
There is no longer a `get_db_schema_files` callback provided for password auth provider modules. Any
|
||||||
|
changes to the database should now be made by the module using the module API class.
|
||||||
|
|
||||||
The module's author should also update any example in the module's configuration to only
|
The module's author should also update any example in the module's configuration to only
|
||||||
use the new `modules` section in Synapse's configuration file (see [this section](index.html#using-modules)
|
use the new `modules` section in Synapse's configuration file (see [this section](index.html#using-modules)
|
||||||
for more info).
|
for more info).
|
||||||
|
|
|
@ -1,3 +1,9 @@
|
||||||
|
<h2 style="color:red">
|
||||||
|
This page of the Synapse documentation is now deprecated. For up to date
|
||||||
|
documentation on setting up or writing a password auth provider module, please see
|
||||||
|
<a href="modules.md">this page</a>.
|
||||||
|
</h2>
|
||||||
|
|
||||||
# Password auth provider modules
|
# Password auth provider modules
|
||||||
|
|
||||||
Password auth providers offer a way for server administrators to
|
Password auth providers offer a way for server administrators to
|
||||||
|
|
|
@ -2260,34 +2260,6 @@ email:
|
||||||
#email_validation: "[%(server_name)s] Validate your email"
|
#email_validation: "[%(server_name)s] Validate your email"
|
||||||
|
|
||||||
|
|
||||||
# Password providers allow homeserver administrators to integrate
|
|
||||||
# their Synapse installation with existing authentication methods
|
|
||||||
# ex. LDAP, external tokens, etc.
|
|
||||||
#
|
|
||||||
# For more information and known implementations, please see
|
|
||||||
# https://matrix-org.github.io/synapse/latest/password_auth_providers.html
|
|
||||||
#
|
|
||||||
# Note: instances wishing to use SAML or CAS authentication should
|
|
||||||
# instead use the `saml2_config` or `cas_config` options,
|
|
||||||
# respectively.
|
|
||||||
#
|
|
||||||
password_providers:
|
|
||||||
# # Example config for an LDAP auth provider
|
|
||||||
# - module: "ldap_auth_provider.LdapAuthProvider"
|
|
||||||
# config:
|
|
||||||
# enabled: true
|
|
||||||
# uri: "ldap://ldap.example.com:389"
|
|
||||||
# start_tls: true
|
|
||||||
# base: "ou=users,dc=example,dc=com"
|
|
||||||
# attributes:
|
|
||||||
# uid: "cn"
|
|
||||||
# mail: "email"
|
|
||||||
# name: "givenName"
|
|
||||||
# #bind_dn:
|
|
||||||
# #bind_password:
|
|
||||||
# #filter: "(objectClass=posixAccount)"
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
## Push ##
|
## Push ##
|
||||||
|
|
||||||
|
|
|
@ -42,6 +42,7 @@ from synapse.crypto import context_factory
|
||||||
from synapse.events.presence_router import load_legacy_presence_router
|
from synapse.events.presence_router import load_legacy_presence_router
|
||||||
from synapse.events.spamcheck import load_legacy_spam_checkers
|
from synapse.events.spamcheck import load_legacy_spam_checkers
|
||||||
from synapse.events.third_party_rules import load_legacy_third_party_event_rules
|
from synapse.events.third_party_rules import load_legacy_third_party_event_rules
|
||||||
|
from synapse.handlers.auth import load_legacy_password_auth_providers
|
||||||
from synapse.logging.context import PreserveLoggingContext
|
from synapse.logging.context import PreserveLoggingContext
|
||||||
from synapse.metrics.background_process_metrics import wrap_as_background_process
|
from synapse.metrics.background_process_metrics import wrap_as_background_process
|
||||||
from synapse.metrics.jemalloc import setup_jemalloc_stats
|
from synapse.metrics.jemalloc import setup_jemalloc_stats
|
||||||
|
@ -379,6 +380,7 @@ async def start(hs: "HomeServer"):
|
||||||
load_legacy_spam_checkers(hs)
|
load_legacy_spam_checkers(hs)
|
||||||
load_legacy_third_party_event_rules(hs)
|
load_legacy_third_party_event_rules(hs)
|
||||||
load_legacy_presence_router(hs)
|
load_legacy_presence_router(hs)
|
||||||
|
load_legacy_password_auth_providers(hs)
|
||||||
|
|
||||||
# If we've configured an expiry time for caches, start the background job now.
|
# If we've configured an expiry time for caches, start the background job now.
|
||||||
setup_expire_lru_cache_entries(hs)
|
setup_expire_lru_cache_entries(hs)
|
||||||
|
|
|
@ -25,6 +25,29 @@ class PasswordAuthProviderConfig(Config):
|
||||||
section = "authproviders"
|
section = "authproviders"
|
||||||
|
|
||||||
def read_config(self, config, **kwargs):
|
def read_config(self, config, **kwargs):
|
||||||
|
"""Parses the old password auth providers config. The config format looks like this:
|
||||||
|
|
||||||
|
password_providers:
|
||||||
|
# Example config for an LDAP auth provider
|
||||||
|
- module: "ldap_auth_provider.LdapAuthProvider"
|
||||||
|
config:
|
||||||
|
enabled: true
|
||||||
|
uri: "ldap://ldap.example.com:389"
|
||||||
|
start_tls: true
|
||||||
|
base: "ou=users,dc=example,dc=com"
|
||||||
|
attributes:
|
||||||
|
uid: "cn"
|
||||||
|
mail: "email"
|
||||||
|
name: "givenName"
|
||||||
|
#bind_dn:
|
||||||
|
#bind_password:
|
||||||
|
#filter: "(objectClass=posixAccount)"
|
||||||
|
|
||||||
|
We expect admins to use modules for this feature (which is why it doesn't appear
|
||||||
|
in the sample config file), but we want to keep support for it around for a bit
|
||||||
|
for backwards compatibility.
|
||||||
|
"""
|
||||||
|
|
||||||
self.password_providers: List[Tuple[Type, Any]] = []
|
self.password_providers: List[Tuple[Type, Any]] = []
|
||||||
providers = []
|
providers = []
|
||||||
|
|
||||||
|
@ -49,33 +72,3 @@ class PasswordAuthProviderConfig(Config):
|
||||||
)
|
)
|
||||||
|
|
||||||
self.password_providers.append((provider_class, provider_config))
|
self.password_providers.append((provider_class, provider_config))
|
||||||
|
|
||||||
def generate_config_section(self, **kwargs):
|
|
||||||
return """\
|
|
||||||
# Password providers allow homeserver administrators to integrate
|
|
||||||
# their Synapse installation with existing authentication methods
|
|
||||||
# ex. LDAP, external tokens, etc.
|
|
||||||
#
|
|
||||||
# For more information and known implementations, please see
|
|
||||||
# https://matrix-org.github.io/synapse/latest/password_auth_providers.html
|
|
||||||
#
|
|
||||||
# Note: instances wishing to use SAML or CAS authentication should
|
|
||||||
# instead use the `saml2_config` or `cas_config` options,
|
|
||||||
# respectively.
|
|
||||||
#
|
|
||||||
password_providers:
|
|
||||||
# # Example config for an LDAP auth provider
|
|
||||||
# - module: "ldap_auth_provider.LdapAuthProvider"
|
|
||||||
# config:
|
|
||||||
# enabled: true
|
|
||||||
# uri: "ldap://ldap.example.com:389"
|
|
||||||
# start_tls: true
|
|
||||||
# base: "ou=users,dc=example,dc=com"
|
|
||||||
# attributes:
|
|
||||||
# uid: "cn"
|
|
||||||
# mail: "email"
|
|
||||||
# name: "givenName"
|
|
||||||
# #bind_dn:
|
|
||||||
# #bind_password:
|
|
||||||
# #filter: "(objectClass=posixAccount)"
|
|
||||||
"""
|
|
||||||
|
|
|
@ -200,46 +200,13 @@ class AuthHandler:
|
||||||
|
|
||||||
self.bcrypt_rounds = hs.config.registration.bcrypt_rounds
|
self.bcrypt_rounds = hs.config.registration.bcrypt_rounds
|
||||||
|
|
||||||
# we can't use hs.get_module_api() here, because to do so will create an
|
self.password_auth_provider = hs.get_password_auth_provider()
|
||||||
# import loop.
|
|
||||||
#
|
|
||||||
# TODO: refactor this class to separate the lower-level stuff that
|
|
||||||
# ModuleApi can use from the higher-level stuff that uses ModuleApi, as
|
|
||||||
# better way to break the loop
|
|
||||||
account_handler = ModuleApi(hs, self)
|
|
||||||
|
|
||||||
self.password_providers = [
|
|
||||||
PasswordProvider.load(module, config, account_handler)
|
|
||||||
for module, config in hs.config.authproviders.password_providers
|
|
||||||
]
|
|
||||||
|
|
||||||
logger.info("Extra password_providers: %s", self.password_providers)
|
|
||||||
|
|
||||||
self.hs = hs # FIXME better possibility to access registrationHandler later?
|
self.hs = hs # FIXME better possibility to access registrationHandler later?
|
||||||
self.macaroon_gen = hs.get_macaroon_generator()
|
self.macaroon_gen = hs.get_macaroon_generator()
|
||||||
self._password_enabled = hs.config.auth.password_enabled
|
self._password_enabled = hs.config.auth.password_enabled
|
||||||
self._password_localdb_enabled = hs.config.auth.password_localdb_enabled
|
self._password_localdb_enabled = hs.config.auth.password_localdb_enabled
|
||||||
|
|
||||||
# start out by assuming PASSWORD is enabled; we will remove it later if not.
|
|
||||||
login_types = set()
|
|
||||||
if self._password_localdb_enabled:
|
|
||||||
login_types.add(LoginType.PASSWORD)
|
|
||||||
|
|
||||||
for provider in self.password_providers:
|
|
||||||
login_types.update(provider.get_supported_login_types().keys())
|
|
||||||
|
|
||||||
if not self._password_enabled:
|
|
||||||
login_types.discard(LoginType.PASSWORD)
|
|
||||||
|
|
||||||
# Some clients just pick the first type in the list. In this case, we want
|
|
||||||
# them to use PASSWORD (rather than token or whatever), so we want to make sure
|
|
||||||
# that comes first, where it's present.
|
|
||||||
self._supported_login_types = []
|
|
||||||
if LoginType.PASSWORD in login_types:
|
|
||||||
self._supported_login_types.append(LoginType.PASSWORD)
|
|
||||||
login_types.remove(LoginType.PASSWORD)
|
|
||||||
self._supported_login_types.extend(login_types)
|
|
||||||
|
|
||||||
# Ratelimiter for failed auth during UIA. Uses same ratelimit config
|
# Ratelimiter for failed auth during UIA. Uses same ratelimit config
|
||||||
# as per `rc_login.failed_attempts`.
|
# as per `rc_login.failed_attempts`.
|
||||||
self._failed_uia_attempts_ratelimiter = Ratelimiter(
|
self._failed_uia_attempts_ratelimiter = Ratelimiter(
|
||||||
|
@ -427,11 +394,10 @@ class AuthHandler:
|
||||||
ui_auth_types.add(LoginType.PASSWORD)
|
ui_auth_types.add(LoginType.PASSWORD)
|
||||||
|
|
||||||
# also allow auth from password providers
|
# also allow auth from password providers
|
||||||
for provider in self.password_providers:
|
for t in self.password_auth_provider.get_supported_login_types().keys():
|
||||||
for t in provider.get_supported_login_types().keys():
|
if t == LoginType.PASSWORD and not self._password_enabled:
|
||||||
if t == LoginType.PASSWORD and not self._password_enabled:
|
continue
|
||||||
continue
|
ui_auth_types.add(t)
|
||||||
ui_auth_types.add(t)
|
|
||||||
|
|
||||||
# if sso is enabled, allow the user to log in via SSO iff they have a mapping
|
# if sso is enabled, allow the user to log in via SSO iff they have a mapping
|
||||||
# from sso to mxid.
|
# from sso to mxid.
|
||||||
|
@ -1038,7 +1004,25 @@ class AuthHandler:
|
||||||
Returns:
|
Returns:
|
||||||
login types
|
login types
|
||||||
"""
|
"""
|
||||||
return self._supported_login_types
|
# Load any login types registered by modules
|
||||||
|
# This is stored in the password_auth_provider so this doesn't trigger
|
||||||
|
# any callbacks
|
||||||
|
types = list(self.password_auth_provider.get_supported_login_types().keys())
|
||||||
|
|
||||||
|
# This list should include PASSWORD if (either _password_localdb_enabled is
|
||||||
|
# true or if one of the modules registered it) AND _password_enabled is true
|
||||||
|
# Also:
|
||||||
|
# Some clients just pick the first type in the list. In this case, we want
|
||||||
|
# them to use PASSWORD (rather than token or whatever), so we want to make sure
|
||||||
|
# that comes first, where it's present.
|
||||||
|
if LoginType.PASSWORD in types:
|
||||||
|
types.remove(LoginType.PASSWORD)
|
||||||
|
if self._password_enabled:
|
||||||
|
types.insert(0, LoginType.PASSWORD)
|
||||||
|
elif self._password_localdb_enabled and self._password_enabled:
|
||||||
|
types.insert(0, LoginType.PASSWORD)
|
||||||
|
|
||||||
|
return types
|
||||||
|
|
||||||
async def validate_login(
|
async def validate_login(
|
||||||
self,
|
self,
|
||||||
|
@ -1217,15 +1201,20 @@ class AuthHandler:
|
||||||
|
|
||||||
known_login_type = False
|
known_login_type = False
|
||||||
|
|
||||||
for provider in self.password_providers:
|
# Check if login_type matches a type registered by one of the modules
|
||||||
supported_login_types = provider.get_supported_login_types()
|
# We don't need to remove LoginType.PASSWORD from the list if password login is
|
||||||
if login_type not in supported_login_types:
|
# disabled, since if that were the case then by this point we know that the
|
||||||
# this password provider doesn't understand this login type
|
# login_type is not LoginType.PASSWORD
|
||||||
continue
|
supported_login_types = self.password_auth_provider.get_supported_login_types()
|
||||||
|
# check if the login type being used is supported by a module
|
||||||
|
if login_type in supported_login_types:
|
||||||
|
# Make a note that this login type is supported by the server
|
||||||
known_login_type = True
|
known_login_type = True
|
||||||
|
# Get all the fields expected for this login types
|
||||||
login_fields = supported_login_types[login_type]
|
login_fields = supported_login_types[login_type]
|
||||||
|
|
||||||
|
# go through the login submission and keep track of which required fields are
|
||||||
|
# provided/not provided
|
||||||
missing_fields = []
|
missing_fields = []
|
||||||
login_dict = {}
|
login_dict = {}
|
||||||
for f in login_fields:
|
for f in login_fields:
|
||||||
|
@ -1233,6 +1222,7 @@ class AuthHandler:
|
||||||
missing_fields.append(f)
|
missing_fields.append(f)
|
||||||
else:
|
else:
|
||||||
login_dict[f] = login_submission[f]
|
login_dict[f] = login_submission[f]
|
||||||
|
# raise an error if any of the expected fields for that login type weren't provided
|
||||||
if missing_fields:
|
if missing_fields:
|
||||||
raise SynapseError(
|
raise SynapseError(
|
||||||
400,
|
400,
|
||||||
|
@ -1240,10 +1230,15 @@ class AuthHandler:
|
||||||
% (login_type, missing_fields),
|
% (login_type, missing_fields),
|
||||||
)
|
)
|
||||||
|
|
||||||
result = await provider.check_auth(username, login_type, login_dict)
|
# call all of the check_auth hooks for that login_type
|
||||||
|
# it will return a result once the first success is found (or None otherwise)
|
||||||
|
result = await self.password_auth_provider.check_auth(
|
||||||
|
username, login_type, login_dict
|
||||||
|
)
|
||||||
if result:
|
if result:
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
# if no module managed to authenticate the user, then fallback to built in password based auth
|
||||||
if login_type == LoginType.PASSWORD and self._password_localdb_enabled:
|
if login_type == LoginType.PASSWORD and self._password_localdb_enabled:
|
||||||
known_login_type = True
|
known_login_type = True
|
||||||
|
|
||||||
|
@ -1282,11 +1277,16 @@ class AuthHandler:
|
||||||
completed login/registration, or `None`. If authentication was
|
completed login/registration, or `None`. If authentication was
|
||||||
unsuccessful, `user_id` and `callback` are both `None`.
|
unsuccessful, `user_id` and `callback` are both `None`.
|
||||||
"""
|
"""
|
||||||
for provider in self.password_providers:
|
# call all of the check_3pid_auth callbacks
|
||||||
result = await provider.check_3pid_auth(medium, address, password)
|
# Result will be from the first callback that returns something other than None
|
||||||
if result:
|
# If all the callbacks return None, then result is also set to None
|
||||||
return result
|
result = await self.password_auth_provider.check_3pid_auth(
|
||||||
|
medium, address, password
|
||||||
|
)
|
||||||
|
if result:
|
||||||
|
return result
|
||||||
|
|
||||||
|
# if result is None then return (None, None)
|
||||||
return None, None
|
return None, None
|
||||||
|
|
||||||
async def _check_local_password(self, user_id: str, password: str) -> Optional[str]:
|
async def _check_local_password(self, user_id: str, password: str) -> Optional[str]:
|
||||||
|
@ -1365,13 +1365,12 @@ class AuthHandler:
|
||||||
user_info = await self.auth.get_user_by_access_token(access_token)
|
user_info = await self.auth.get_user_by_access_token(access_token)
|
||||||
await self.store.delete_access_token(access_token)
|
await self.store.delete_access_token(access_token)
|
||||||
|
|
||||||
# see if any of our auth providers want to know about this
|
# see if any modules want to know about this
|
||||||
for provider in self.password_providers:
|
await self.password_auth_provider.on_logged_out(
|
||||||
await provider.on_logged_out(
|
user_id=user_info.user_id,
|
||||||
user_id=user_info.user_id,
|
device_id=user_info.device_id,
|
||||||
device_id=user_info.device_id,
|
access_token=access_token,
|
||||||
access_token=access_token,
|
)
|
||||||
)
|
|
||||||
|
|
||||||
# delete pushers associated with this access token
|
# delete pushers associated with this access token
|
||||||
if user_info.token_id is not None:
|
if user_info.token_id is not None:
|
||||||
|
@ -1398,12 +1397,11 @@ class AuthHandler:
|
||||||
user_id, except_token_id=except_token_id, device_id=device_id
|
user_id, except_token_id=except_token_id, device_id=device_id
|
||||||
)
|
)
|
||||||
|
|
||||||
# see if any of our auth providers want to know about this
|
# see if any modules want to know about this
|
||||||
for provider in self.password_providers:
|
for token, _, device_id in tokens_and_devices:
|
||||||
for token, _, device_id in tokens_and_devices:
|
await self.password_auth_provider.on_logged_out(
|
||||||
await provider.on_logged_out(
|
user_id=user_id, device_id=device_id, access_token=token
|
||||||
user_id=user_id, device_id=device_id, access_token=token
|
)
|
||||||
)
|
|
||||||
|
|
||||||
# delete pushers associated with the access tokens
|
# delete pushers associated with the access tokens
|
||||||
await self.hs.get_pusherpool().remove_pushers_by_access_token(
|
await self.hs.get_pusherpool().remove_pushers_by_access_token(
|
||||||
|
@ -1811,40 +1809,228 @@ class MacaroonGenerator:
|
||||||
return macaroon
|
return macaroon
|
||||||
|
|
||||||
|
|
||||||
class PasswordProvider:
|
def load_legacy_password_auth_providers(hs: "HomeServer") -> None:
|
||||||
"""Wrapper for a password auth provider module
|
module_api = hs.get_module_api()
|
||||||
|
for module, config in hs.config.authproviders.password_providers:
|
||||||
|
load_single_legacy_password_auth_provider(
|
||||||
|
module=module, config=config, api=module_api
|
||||||
|
)
|
||||||
|
|
||||||
This class abstracts out all of the backwards-compatibility hacks for
|
|
||||||
password providers, to provide a consistent interface.
|
def load_single_legacy_password_auth_provider(
|
||||||
|
module: Type, config: JsonDict, api: ModuleApi
|
||||||
|
) -> None:
|
||||||
|
try:
|
||||||
|
provider = module(config=config, account_handler=api)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("Error while initializing %r: %s", module, e)
|
||||||
|
raise
|
||||||
|
|
||||||
|
# The known hooks. If a module implements a method who's name appears in this set
|
||||||
|
# we'll want to register it
|
||||||
|
password_auth_provider_methods = {
|
||||||
|
"check_3pid_auth",
|
||||||
|
"on_logged_out",
|
||||||
|
}
|
||||||
|
|
||||||
|
# All methods that the module provides should be async, but this wasn't enforced
|
||||||
|
# in the old module system, so we wrap them if needed
|
||||||
|
def async_wrapper(f: Optional[Callable]) -> Optional[Callable[..., Awaitable]]:
|
||||||
|
# f might be None if the callback isn't implemented by the module. In this
|
||||||
|
# case we don't want to register a callback at all so we return None.
|
||||||
|
if f is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# We need to wrap check_password because its old form would return a boolean
|
||||||
|
# but we now want it to behave just like check_auth() and return the matrix id of
|
||||||
|
# the user if authentication succeeded or None otherwise
|
||||||
|
if f.__name__ == "check_password":
|
||||||
|
|
||||||
|
async def wrapped_check_password(
|
||||||
|
username: str, login_type: str, login_dict: JsonDict
|
||||||
|
) -> Optional[Tuple[str, Optional[Callable]]]:
|
||||||
|
# We've already made sure f is not None above, but mypy doesn't do well
|
||||||
|
# across function boundaries so we need to tell it f is definitely not
|
||||||
|
# None.
|
||||||
|
assert f is not None
|
||||||
|
|
||||||
|
matrix_user_id = api.get_qualified_user_id(username)
|
||||||
|
password = login_dict["password"]
|
||||||
|
|
||||||
|
is_valid = await f(matrix_user_id, password)
|
||||||
|
|
||||||
|
if is_valid:
|
||||||
|
return matrix_user_id, None
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
return wrapped_check_password
|
||||||
|
|
||||||
|
# We need to wrap check_auth as in the old form it could return
|
||||||
|
# just a str, but now it must return Optional[Tuple[str, Optional[Callable]]
|
||||||
|
if f.__name__ == "check_auth":
|
||||||
|
|
||||||
|
async def wrapped_check_auth(
|
||||||
|
username: str, login_type: str, login_dict: JsonDict
|
||||||
|
) -> Optional[Tuple[str, Optional[Callable]]]:
|
||||||
|
# We've already made sure f is not None above, but mypy doesn't do well
|
||||||
|
# across function boundaries so we need to tell it f is definitely not
|
||||||
|
# None.
|
||||||
|
assert f is not None
|
||||||
|
|
||||||
|
result = await f(username, login_type, login_dict)
|
||||||
|
|
||||||
|
if isinstance(result, str):
|
||||||
|
return result, None
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
return wrapped_check_auth
|
||||||
|
|
||||||
|
# We need to wrap check_3pid_auth as in the old form it could return
|
||||||
|
# just a str, but now it must return Optional[Tuple[str, Optional[Callable]]
|
||||||
|
if f.__name__ == "check_3pid_auth":
|
||||||
|
|
||||||
|
async def wrapped_check_3pid_auth(
|
||||||
|
medium: str, address: str, password: str
|
||||||
|
) -> Optional[Tuple[str, Optional[Callable]]]:
|
||||||
|
# We've already made sure f is not None above, but mypy doesn't do well
|
||||||
|
# across function boundaries so we need to tell it f is definitely not
|
||||||
|
# None.
|
||||||
|
assert f is not None
|
||||||
|
|
||||||
|
result = await f(medium, address, password)
|
||||||
|
|
||||||
|
if isinstance(result, str):
|
||||||
|
return result, None
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
return wrapped_check_3pid_auth
|
||||||
|
|
||||||
|
def run(*args: Tuple, **kwargs: Dict) -> Awaitable:
|
||||||
|
# mypy doesn't do well across function boundaries so we need to tell it
|
||||||
|
# f is definitely not None.
|
||||||
|
assert f is not None
|
||||||
|
|
||||||
|
return maybe_awaitable(f(*args, **kwargs))
|
||||||
|
|
||||||
|
return run
|
||||||
|
|
||||||
|
# populate hooks with the implemented methods, wrapped with async_wrapper
|
||||||
|
hooks = {
|
||||||
|
hook: async_wrapper(getattr(provider, hook, None))
|
||||||
|
for hook in password_auth_provider_methods
|
||||||
|
}
|
||||||
|
|
||||||
|
supported_login_types = {}
|
||||||
|
# call get_supported_login_types and add that to the dict
|
||||||
|
g = getattr(provider, "get_supported_login_types", None)
|
||||||
|
if g is not None:
|
||||||
|
# Note the old module style also called get_supported_login_types at loading time
|
||||||
|
# and it is synchronous
|
||||||
|
supported_login_types.update(g())
|
||||||
|
|
||||||
|
auth_checkers = {}
|
||||||
|
# Legacy modules have a check_auth method which expects to be called with one of
|
||||||
|
# the keys returned by get_supported_login_types. New style modules register a
|
||||||
|
# dictionary of login_type->check_auth_method mappings
|
||||||
|
check_auth = async_wrapper(getattr(provider, "check_auth", None))
|
||||||
|
if check_auth is not None:
|
||||||
|
for login_type, fields in supported_login_types.items():
|
||||||
|
# need tuple(fields) since fields can be any Iterable type (so may not be hashable)
|
||||||
|
auth_checkers[(login_type, tuple(fields))] = check_auth
|
||||||
|
|
||||||
|
# if it has a "check_password" method then it should handle all auth checks
|
||||||
|
# with login type of LoginType.PASSWORD
|
||||||
|
check_password = async_wrapper(getattr(provider, "check_password", None))
|
||||||
|
if check_password is not None:
|
||||||
|
# need to use a tuple here for ("password",) not a list since lists aren't hashable
|
||||||
|
auth_checkers[(LoginType.PASSWORD, ("password",))] = check_password
|
||||||
|
|
||||||
|
api.register_password_auth_provider_callbacks(hooks, auth_checkers=auth_checkers)
|
||||||
|
|
||||||
|
|
||||||
|
CHECK_3PID_AUTH_CALLBACK = Callable[
|
||||||
|
[str, str, str],
|
||||||
|
Awaitable[
|
||||||
|
Optional[Tuple[str, Optional[Callable[["LoginResponse"], Awaitable[None]]]]]
|
||||||
|
],
|
||||||
|
]
|
||||||
|
ON_LOGGED_OUT_CALLBACK = Callable[[str, Optional[str], str], Awaitable]
|
||||||
|
CHECK_AUTH_CALLBACK = Callable[
|
||||||
|
[str, str, JsonDict],
|
||||||
|
Awaitable[
|
||||||
|
Optional[Tuple[str, Optional[Callable[["LoginResponse"], Awaitable[None]]]]]
|
||||||
|
],
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
class PasswordAuthProvider:
|
||||||
|
"""
|
||||||
|
A class that the AuthHandler calls when authenticating users
|
||||||
|
It allows modules to provide alternative methods for authentication
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@classmethod
|
def __init__(self) -> None:
|
||||||
def load(
|
# lists of callbacks
|
||||||
cls, module: Type, config: JsonDict, module_api: ModuleApi
|
self.check_3pid_auth_callbacks: List[CHECK_3PID_AUTH_CALLBACK] = []
|
||||||
) -> "PasswordProvider":
|
self.on_logged_out_callbacks: List[ON_LOGGED_OUT_CALLBACK] = []
|
||||||
try:
|
|
||||||
pp = module(config=config, account_handler=module_api)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error("Error while initializing %r: %s", module, e)
|
|
||||||
raise
|
|
||||||
return cls(pp, module_api)
|
|
||||||
|
|
||||||
def __init__(self, pp: "PasswordProvider", module_api: ModuleApi):
|
# Mapping from login type to login parameters
|
||||||
self._pp = pp
|
self._supported_login_types: Dict[str, Iterable[str]] = {}
|
||||||
self._module_api = module_api
|
|
||||||
|
|
||||||
self._supported_login_types = {}
|
# Mapping from login type to auth checker callbacks
|
||||||
|
self.auth_checker_callbacks: Dict[str, List[CHECK_AUTH_CALLBACK]] = {}
|
||||||
|
|
||||||
# grandfather in check_password support
|
def register_password_auth_provider_callbacks(
|
||||||
if hasattr(self._pp, "check_password"):
|
self,
|
||||||
self._supported_login_types[LoginType.PASSWORD] = ("password",)
|
check_3pid_auth: Optional[CHECK_3PID_AUTH_CALLBACK] = None,
|
||||||
|
on_logged_out: Optional[ON_LOGGED_OUT_CALLBACK] = None,
|
||||||
|
auth_checkers: Optional[Dict[Tuple[str, Tuple], CHECK_AUTH_CALLBACK]] = None,
|
||||||
|
) -> None:
|
||||||
|
# Register check_3pid_auth callback
|
||||||
|
if check_3pid_auth is not None:
|
||||||
|
self.check_3pid_auth_callbacks.append(check_3pid_auth)
|
||||||
|
|
||||||
g = getattr(self._pp, "get_supported_login_types", None)
|
# register on_logged_out callback
|
||||||
if g:
|
if on_logged_out is not None:
|
||||||
self._supported_login_types.update(g())
|
self.on_logged_out_callbacks.append(on_logged_out)
|
||||||
|
|
||||||
def __str__(self) -> str:
|
if auth_checkers is not None:
|
||||||
return str(self._pp)
|
# register a new supported login_type
|
||||||
|
# Iterate through all of the types being registered
|
||||||
|
for (login_type, fields), callback in auth_checkers.items():
|
||||||
|
# Note: fields may be empty here. This would allow a modules auth checker to
|
||||||
|
# be called with just 'login_type' and no password or other secrets
|
||||||
|
|
||||||
|
# Need to check that all the field names are strings or may get nasty errors later
|
||||||
|
for f in fields:
|
||||||
|
if not isinstance(f, str):
|
||||||
|
raise RuntimeError(
|
||||||
|
"A module tried to register support for login type: %s with parameters %s"
|
||||||
|
" but all parameter names must be strings"
|
||||||
|
% (login_type, fields)
|
||||||
|
)
|
||||||
|
|
||||||
|
# 2 modules supporting the same login type must expect the same fields
|
||||||
|
# e.g. 1 can't expect "pass" if the other expects "password"
|
||||||
|
# so throw an exception if that happens
|
||||||
|
if login_type not in self._supported_login_types.get(login_type, []):
|
||||||
|
self._supported_login_types[login_type] = fields
|
||||||
|
else:
|
||||||
|
fields_currently_supported = self._supported_login_types.get(
|
||||||
|
login_type
|
||||||
|
)
|
||||||
|
if fields_currently_supported != fields:
|
||||||
|
raise RuntimeError(
|
||||||
|
"A module tried to register support for login type: %s with parameters %s"
|
||||||
|
" but another module had already registered support for that type with parameters %s"
|
||||||
|
% (login_type, fields, fields_currently_supported)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Add the new method to the list of auth_checker_callbacks for this login type
|
||||||
|
self.auth_checker_callbacks.setdefault(login_type, []).append(callback)
|
||||||
|
|
||||||
def get_supported_login_types(self) -> Mapping[str, Iterable[str]]:
|
def get_supported_login_types(self) -> Mapping[str, Iterable[str]]:
|
||||||
"""Get the login types supported by this password provider
|
"""Get the login types supported by this password provider
|
||||||
|
@ -1852,20 +2038,15 @@ class PasswordProvider:
|
||||||
Returns a map from a login type identifier (such as m.login.password) to an
|
Returns a map from a login type identifier (such as m.login.password) to an
|
||||||
iterable giving the fields which must be provided by the user in the submission
|
iterable giving the fields which must be provided by the user in the submission
|
||||||
to the /login API.
|
to the /login API.
|
||||||
|
|
||||||
This wrapper adds m.login.password to the list if the underlying password
|
|
||||||
provider supports the check_password() api.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
return self._supported_login_types
|
return self._supported_login_types
|
||||||
|
|
||||||
async def check_auth(
|
async def check_auth(
|
||||||
self, username: str, login_type: str, login_dict: JsonDict
|
self, username: str, login_type: str, login_dict: JsonDict
|
||||||
) -> Optional[Tuple[str, Optional[Callable]]]:
|
) -> Optional[Tuple[str, Optional[Callable[["LoginResponse"], Awaitable[None]]]]]:
|
||||||
"""Check if the user has presented valid login credentials
|
"""Check if the user has presented valid login credentials
|
||||||
|
|
||||||
This wrapper also calls check_password() if the underlying password provider
|
|
||||||
supports the check_password() api and the login type is m.login.password.
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
username: user id presented by the client. Either an MXID or an unqualified
|
username: user id presented by the client. Either an MXID or an unqualified
|
||||||
username.
|
username.
|
||||||
|
@ -1879,63 +2060,130 @@ class PasswordProvider:
|
||||||
user, and `callback` is an optional callback which will be called with the
|
user, and `callback` is an optional callback which will be called with the
|
||||||
result from the /login call (including access_token, device_id, etc.)
|
result from the /login call (including access_token, device_id, etc.)
|
||||||
"""
|
"""
|
||||||
# first grandfather in a call to check_password
|
|
||||||
if login_type == LoginType.PASSWORD:
|
|
||||||
check_password = getattr(self._pp, "check_password", None)
|
|
||||||
if check_password:
|
|
||||||
qualified_user_id = self._module_api.get_qualified_user_id(username)
|
|
||||||
is_valid = await check_password(
|
|
||||||
qualified_user_id, login_dict["password"]
|
|
||||||
)
|
|
||||||
if is_valid:
|
|
||||||
return qualified_user_id, None
|
|
||||||
|
|
||||||
check_auth = getattr(self._pp, "check_auth", None)
|
# Go through all callbacks for the login type until one returns with a value
|
||||||
if not check_auth:
|
# other than None (i.e. until a callback returns a success)
|
||||||
return None
|
for callback in self.auth_checker_callbacks[login_type]:
|
||||||
result = await check_auth(username, login_type, login_dict)
|
try:
|
||||||
|
result = await callback(username, login_type, login_dict)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning("Failed to run module API callback %s: %s", callback, e)
|
||||||
|
continue
|
||||||
|
|
||||||
# Check if the return value is a str or a tuple
|
if result is not None:
|
||||||
if isinstance(result, str):
|
# Check that the callback returned a Tuple[str, Optional[Callable]]
|
||||||
# If it's a str, set callback function to None
|
# "type: ignore[unreachable]" is used after some isinstance checks because mypy thinks
|
||||||
return result, None
|
# result is always the right type, but as it is 3rd party code it might not be
|
||||||
|
|
||||||
return result
|
if not isinstance(result, tuple) or len(result) != 2:
|
||||||
|
logger.warning(
|
||||||
|
"Wrong type returned by module API callback %s: %s, expected"
|
||||||
|
" Optional[Tuple[str, Optional[Callable]]]",
|
||||||
|
callback,
|
||||||
|
result,
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
|
||||||
|
# pull out the two parts of the tuple so we can do type checking
|
||||||
|
str_result, callback_result = result
|
||||||
|
|
||||||
|
# the 1st item in the tuple should be a str
|
||||||
|
if not isinstance(str_result, str):
|
||||||
|
logger.warning( # type: ignore[unreachable]
|
||||||
|
"Wrong type returned by module API callback %s: %s, expected"
|
||||||
|
" Optional[Tuple[str, Optional[Callable]]]",
|
||||||
|
callback,
|
||||||
|
result,
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
|
||||||
|
# the second should be Optional[Callable]
|
||||||
|
if callback_result is not None:
|
||||||
|
if not callable(callback_result):
|
||||||
|
logger.warning( # type: ignore[unreachable]
|
||||||
|
"Wrong type returned by module API callback %s: %s, expected"
|
||||||
|
" Optional[Tuple[str, Optional[Callable]]]",
|
||||||
|
callback,
|
||||||
|
result,
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
|
||||||
|
# The result is a (str, Optional[callback]) tuple so return the successful result
|
||||||
|
return result
|
||||||
|
|
||||||
|
# If this point has been reached then none of the callbacks successfully authenticated
|
||||||
|
# the user so return None
|
||||||
|
return None
|
||||||
|
|
||||||
async def check_3pid_auth(
|
async def check_3pid_auth(
|
||||||
self, medium: str, address: str, password: str
|
self, medium: str, address: str, password: str
|
||||||
) -> Optional[Tuple[str, Optional[Callable]]]:
|
) -> Optional[Tuple[str, Optional[Callable[["LoginResponse"], Awaitable[None]]]]]:
|
||||||
g = getattr(self._pp, "check_3pid_auth", None)
|
|
||||||
if not g:
|
|
||||||
return None
|
|
||||||
|
|
||||||
# This function is able to return a deferred that either
|
# This function is able to return a deferred that either
|
||||||
# resolves None, meaning authentication failure, or upon
|
# resolves None, meaning authentication failure, or upon
|
||||||
# success, to a str (which is the user_id) or a tuple of
|
# success, to a str (which is the user_id) or a tuple of
|
||||||
# (user_id, callback_func), where callback_func should be run
|
# (user_id, callback_func), where callback_func should be run
|
||||||
# after we've finished everything else
|
# after we've finished everything else
|
||||||
result = await g(medium, address, password)
|
|
||||||
|
|
||||||
# Check if the return value is a str or a tuple
|
for callback in self.check_3pid_auth_callbacks:
|
||||||
if isinstance(result, str):
|
try:
|
||||||
# If it's a str, set callback function to None
|
result = await callback(medium, address, password)
|
||||||
return result, None
|
except Exception as e:
|
||||||
|
logger.warning("Failed to run module API callback %s: %s", callback, e)
|
||||||
|
continue
|
||||||
|
|
||||||
return result
|
if result is not None:
|
||||||
|
# Check that the callback returned a Tuple[str, Optional[Callable]]
|
||||||
|
# "type: ignore[unreachable]" is used after some isinstance checks because mypy thinks
|
||||||
|
# result is always the right type, but as it is 3rd party code it might not be
|
||||||
|
|
||||||
|
if not isinstance(result, tuple) or len(result) != 2:
|
||||||
|
logger.warning(
|
||||||
|
"Wrong type returned by module API callback %s: %s, expected"
|
||||||
|
" Optional[Tuple[str, Optional[Callable]]]",
|
||||||
|
callback,
|
||||||
|
result,
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
|
||||||
|
# pull out the two parts of the tuple so we can do type checking
|
||||||
|
str_result, callback_result = result
|
||||||
|
|
||||||
|
# the 1st item in the tuple should be a str
|
||||||
|
if not isinstance(str_result, str):
|
||||||
|
logger.warning( # type: ignore[unreachable]
|
||||||
|
"Wrong type returned by module API callback %s: %s, expected"
|
||||||
|
" Optional[Tuple[str, Optional[Callable]]]",
|
||||||
|
callback,
|
||||||
|
result,
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
|
||||||
|
# the second should be Optional[Callable]
|
||||||
|
if callback_result is not None:
|
||||||
|
if not callable(callback_result):
|
||||||
|
logger.warning( # type: ignore[unreachable]
|
||||||
|
"Wrong type returned by module API callback %s: %s, expected"
|
||||||
|
" Optional[Tuple[str, Optional[Callable]]]",
|
||||||
|
callback,
|
||||||
|
result,
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
|
||||||
|
# The result is a (str, Optional[callback]) tuple so return the successful result
|
||||||
|
return result
|
||||||
|
|
||||||
|
# If this point has been reached then none of the callbacks successfully authenticated
|
||||||
|
# the user so return None
|
||||||
|
return None
|
||||||
|
|
||||||
async def on_logged_out(
|
async def on_logged_out(
|
||||||
self, user_id: str, device_id: Optional[str], access_token: str
|
self, user_id: str, device_id: Optional[str], access_token: str
|
||||||
) -> None:
|
) -> None:
|
||||||
g = getattr(self._pp, "on_logged_out", None)
|
|
||||||
if not g:
|
|
||||||
return
|
|
||||||
|
|
||||||
# This might return an awaitable, if it does block the log out
|
# call all of the on_logged_out callbacks
|
||||||
# until it completes.
|
for callback in self.on_logged_out_callbacks:
|
||||||
await maybe_awaitable(
|
try:
|
||||||
g(
|
callback(user_id, device_id, access_token)
|
||||||
user_id=user_id,
|
except Exception as e:
|
||||||
device_id=device_id,
|
logger.warning("Failed to run module API callback %s: %s", callback, e)
|
||||||
access_token=access_token,
|
continue
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
|
@ -45,6 +45,7 @@ from synapse.http.servlet import parse_json_object_from_request
|
||||||
from synapse.http.site import SynapseRequest
|
from synapse.http.site import SynapseRequest
|
||||||
from synapse.logging.context import make_deferred_yieldable, run_in_background
|
from synapse.logging.context import make_deferred_yieldable, run_in_background
|
||||||
from synapse.metrics.background_process_metrics import run_as_background_process
|
from synapse.metrics.background_process_metrics import run_as_background_process
|
||||||
|
from synapse.rest.client.login import LoginResponse
|
||||||
from synapse.storage.database import DatabasePool, LoggingTransaction
|
from synapse.storage.database import DatabasePool, LoggingTransaction
|
||||||
from synapse.storage.databases.main.roommember import ProfileInfo
|
from synapse.storage.databases.main.roommember import ProfileInfo
|
||||||
from synapse.storage.state import StateFilter
|
from synapse.storage.state import StateFilter
|
||||||
|
@ -83,6 +84,8 @@ __all__ = [
|
||||||
"DirectServeJsonResource",
|
"DirectServeJsonResource",
|
||||||
"ModuleApi",
|
"ModuleApi",
|
||||||
"PRESENCE_ALL_USERS",
|
"PRESENCE_ALL_USERS",
|
||||||
|
"LoginResponse",
|
||||||
|
"JsonDict",
|
||||||
]
|
]
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
@ -139,6 +142,7 @@ class ModuleApi:
|
||||||
self._spam_checker = hs.get_spam_checker()
|
self._spam_checker = hs.get_spam_checker()
|
||||||
self._account_validity_handler = hs.get_account_validity_handler()
|
self._account_validity_handler = hs.get_account_validity_handler()
|
||||||
self._third_party_event_rules = hs.get_third_party_event_rules()
|
self._third_party_event_rules = hs.get_third_party_event_rules()
|
||||||
|
self._password_auth_provider = hs.get_password_auth_provider()
|
||||||
self._presence_router = hs.get_presence_router()
|
self._presence_router = hs.get_presence_router()
|
||||||
|
|
||||||
#################################################################################
|
#################################################################################
|
||||||
|
@ -164,6 +168,11 @@ class ModuleApi:
|
||||||
"""Registers callbacks for presence router capabilities."""
|
"""Registers callbacks for presence router capabilities."""
|
||||||
return self._presence_router.register_presence_router_callbacks
|
return self._presence_router.register_presence_router_callbacks
|
||||||
|
|
||||||
|
@property
|
||||||
|
def register_password_auth_provider_callbacks(self):
|
||||||
|
"""Registers callbacks for password auth provider capabilities."""
|
||||||
|
return self._password_auth_provider.register_password_auth_provider_callbacks
|
||||||
|
|
||||||
def register_web_resource(self, path: str, resource: IResource):
|
def register_web_resource(self, path: str, resource: IResource):
|
||||||
"""Registers a web resource to be served at the given path.
|
"""Registers a web resource to be served at the given path.
|
||||||
|
|
||||||
|
|
|
@ -65,7 +65,7 @@ from synapse.handlers.account_data import AccountDataHandler
|
||||||
from synapse.handlers.account_validity import AccountValidityHandler
|
from synapse.handlers.account_validity import AccountValidityHandler
|
||||||
from synapse.handlers.admin import AdminHandler
|
from synapse.handlers.admin import AdminHandler
|
||||||
from synapse.handlers.appservice import ApplicationServicesHandler
|
from synapse.handlers.appservice import ApplicationServicesHandler
|
||||||
from synapse.handlers.auth import AuthHandler, MacaroonGenerator
|
from synapse.handlers.auth import AuthHandler, MacaroonGenerator, PasswordAuthProvider
|
||||||
from synapse.handlers.cas import CasHandler
|
from synapse.handlers.cas import CasHandler
|
||||||
from synapse.handlers.deactivate_account import DeactivateAccountHandler
|
from synapse.handlers.deactivate_account import DeactivateAccountHandler
|
||||||
from synapse.handlers.device import DeviceHandler, DeviceWorkerHandler
|
from synapse.handlers.device import DeviceHandler, DeviceWorkerHandler
|
||||||
|
@ -687,6 +687,10 @@ class HomeServer(metaclass=abc.ABCMeta):
|
||||||
def get_third_party_event_rules(self) -> ThirdPartyEventRules:
|
def get_third_party_event_rules(self) -> ThirdPartyEventRules:
|
||||||
return ThirdPartyEventRules(self)
|
return ThirdPartyEventRules(self)
|
||||||
|
|
||||||
|
@cache_in_self
|
||||||
|
def get_password_auth_provider(self) -> PasswordAuthProvider:
|
||||||
|
return PasswordAuthProvider()
|
||||||
|
|
||||||
@cache_in_self
|
@cache_in_self
|
||||||
def get_room_member_handler(self) -> RoomMemberHandler:
|
def get_room_member_handler(self) -> RoomMemberHandler:
|
||||||
if self.config.worker.worker_app:
|
if self.config.worker.worker_app:
|
||||||
|
|
|
@ -549,6 +549,8 @@ def _apply_module_schemas(
|
||||||
database_engine:
|
database_engine:
|
||||||
config: application config
|
config: application config
|
||||||
"""
|
"""
|
||||||
|
# This is the old way for password_auth_provider modules to make changes
|
||||||
|
# to the database. This should instead be done using the module API
|
||||||
for (mod, _config) in config.authproviders.password_providers:
|
for (mod, _config) in config.authproviders.password_providers:
|
||||||
if not hasattr(mod, "get_db_schema_files"):
|
if not hasattr(mod, "get_db_schema_files"):
|
||||||
continue
|
continue
|
||||||
|
|
|
@ -20,6 +20,8 @@ from unittest.mock import Mock
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
|
|
||||||
import synapse
|
import synapse
|
||||||
|
from synapse.handlers.auth import load_legacy_password_auth_providers
|
||||||
|
from synapse.module_api import ModuleApi
|
||||||
from synapse.rest.client import devices, login
|
from synapse.rest.client import devices, login
|
||||||
from synapse.types import JsonDict
|
from synapse.types import JsonDict
|
||||||
|
|
||||||
|
@ -36,8 +38,8 @@ ADDITIONAL_LOGIN_FLOWS = [{"type": "uk.half-shot.msc2778.login.application_servi
|
||||||
mock_password_provider = Mock()
|
mock_password_provider = Mock()
|
||||||
|
|
||||||
|
|
||||||
class PasswordOnlyAuthProvider:
|
class LegacyPasswordOnlyAuthProvider:
|
||||||
"""A password_provider which only implements `check_password`."""
|
"""A legacy password_provider which only implements `check_password`."""
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def parse_config(self):
|
def parse_config(self):
|
||||||
|
@ -50,8 +52,8 @@ class PasswordOnlyAuthProvider:
|
||||||
return mock_password_provider.check_password(*args)
|
return mock_password_provider.check_password(*args)
|
||||||
|
|
||||||
|
|
||||||
class CustomAuthProvider:
|
class LegacyCustomAuthProvider:
|
||||||
"""A password_provider which implements a custom login type."""
|
"""A legacy password_provider which implements a custom login type."""
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def parse_config(self):
|
def parse_config(self):
|
||||||
|
@ -67,7 +69,23 @@ class CustomAuthProvider:
|
||||||
return mock_password_provider.check_auth(*args)
|
return mock_password_provider.check_auth(*args)
|
||||||
|
|
||||||
|
|
||||||
class PasswordCustomAuthProvider:
|
class CustomAuthProvider:
|
||||||
|
"""A module which registers password_auth_provider callbacks for a custom login type."""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def parse_config(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def __init__(self, config, api: ModuleApi):
|
||||||
|
api.register_password_auth_provider_callbacks(
|
||||||
|
auth_checkers={("test.login_type", ("test_field",)): self.check_auth},
|
||||||
|
)
|
||||||
|
|
||||||
|
def check_auth(self, *args):
|
||||||
|
return mock_password_provider.check_auth(*args)
|
||||||
|
|
||||||
|
|
||||||
|
class LegacyPasswordCustomAuthProvider:
|
||||||
"""A password_provider which implements password login via `check_auth`, as well
|
"""A password_provider which implements password login via `check_auth`, as well
|
||||||
as a custom type."""
|
as a custom type."""
|
||||||
|
|
||||||
|
@ -85,8 +103,32 @@ class PasswordCustomAuthProvider:
|
||||||
return mock_password_provider.check_auth(*args)
|
return mock_password_provider.check_auth(*args)
|
||||||
|
|
||||||
|
|
||||||
def providers_config(*providers: Type[Any]) -> dict:
|
class PasswordCustomAuthProvider:
|
||||||
"""Returns a config dict that will enable the given password auth providers"""
|
"""A module which registers password_auth_provider callbacks for a custom login type.
|
||||||
|
as well as a password login"""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def parse_config(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def __init__(self, config, api: ModuleApi):
|
||||||
|
api.register_password_auth_provider_callbacks(
|
||||||
|
auth_checkers={
|
||||||
|
("test.login_type", ("test_field",)): self.check_auth,
|
||||||
|
("m.login.password", ("password",)): self.check_auth,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
pass
|
||||||
|
|
||||||
|
def check_auth(self, *args):
|
||||||
|
return mock_password_provider.check_auth(*args)
|
||||||
|
|
||||||
|
def check_pass(self, *args):
|
||||||
|
return mock_password_provider.check_password(*args)
|
||||||
|
|
||||||
|
|
||||||
|
def legacy_providers_config(*providers: Type[Any]) -> dict:
|
||||||
|
"""Returns a config dict that will enable the given legacy password auth providers"""
|
||||||
return {
|
return {
|
||||||
"password_providers": [
|
"password_providers": [
|
||||||
{"module": "%s.%s" % (__name__, provider.__qualname__), "config": {}}
|
{"module": "%s.%s" % (__name__, provider.__qualname__), "config": {}}
|
||||||
|
@ -95,6 +137,16 @@ def providers_config(*providers: Type[Any]) -> dict:
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def providers_config(*providers: Type[Any]) -> dict:
|
||||||
|
"""Returns a config dict that will enable the given modules"""
|
||||||
|
return {
|
||||||
|
"modules": [
|
||||||
|
{"module": "%s.%s" % (__name__, provider.__qualname__), "config": {}}
|
||||||
|
for provider in providers
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
class PasswordAuthProviderTests(unittest.HomeserverTestCase):
|
class PasswordAuthProviderTests(unittest.HomeserverTestCase):
|
||||||
servlets = [
|
servlets = [
|
||||||
synapse.rest.admin.register_servlets,
|
synapse.rest.admin.register_servlets,
|
||||||
|
@ -107,8 +159,21 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
|
||||||
mock_password_provider.reset_mock()
|
mock_password_provider.reset_mock()
|
||||||
super().setUp()
|
super().setUp()
|
||||||
|
|
||||||
@override_config(providers_config(PasswordOnlyAuthProvider))
|
def make_homeserver(self, reactor, clock):
|
||||||
def test_password_only_auth_provider_login(self):
|
hs = self.setup_test_homeserver()
|
||||||
|
# Load the modules into the homeserver
|
||||||
|
module_api = hs.get_module_api()
|
||||||
|
for module, config in hs.config.modules.loaded_modules:
|
||||||
|
module(config=config, api=module_api)
|
||||||
|
load_legacy_password_auth_providers(hs)
|
||||||
|
|
||||||
|
return hs
|
||||||
|
|
||||||
|
@override_config(legacy_providers_config(LegacyPasswordOnlyAuthProvider))
|
||||||
|
def test_password_only_auth_progiver_login_legacy(self):
|
||||||
|
self.password_only_auth_provider_login_test_body()
|
||||||
|
|
||||||
|
def password_only_auth_provider_login_test_body(self):
|
||||||
# login flows should only have m.login.password
|
# login flows should only have m.login.password
|
||||||
flows = self._get_login_flows()
|
flows = self._get_login_flows()
|
||||||
self.assertEqual(flows, [{"type": "m.login.password"}] + ADDITIONAL_LOGIN_FLOWS)
|
self.assertEqual(flows, [{"type": "m.login.password"}] + ADDITIONAL_LOGIN_FLOWS)
|
||||||
|
@ -138,8 +203,11 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
|
||||||
"@ USER🙂NAME :test", " pASS😢word "
|
"@ USER🙂NAME :test", " pASS😢word "
|
||||||
)
|
)
|
||||||
|
|
||||||
@override_config(providers_config(PasswordOnlyAuthProvider))
|
@override_config(legacy_providers_config(LegacyPasswordOnlyAuthProvider))
|
||||||
def test_password_only_auth_provider_ui_auth(self):
|
def test_password_only_auth_provider_ui_auth_legacy(self):
|
||||||
|
self.password_only_auth_provider_ui_auth_test_body()
|
||||||
|
|
||||||
|
def password_only_auth_provider_ui_auth_test_body(self):
|
||||||
"""UI Auth should delegate correctly to the password provider"""
|
"""UI Auth should delegate correctly to the password provider"""
|
||||||
|
|
||||||
# create the user, otherwise access doesn't work
|
# create the user, otherwise access doesn't work
|
||||||
|
@ -172,8 +240,11 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
|
||||||
self.assertEqual(channel.code, 200)
|
self.assertEqual(channel.code, 200)
|
||||||
mock_password_provider.check_password.assert_called_once_with("@u:test", "p")
|
mock_password_provider.check_password.assert_called_once_with("@u:test", "p")
|
||||||
|
|
||||||
@override_config(providers_config(PasswordOnlyAuthProvider))
|
@override_config(legacy_providers_config(LegacyPasswordOnlyAuthProvider))
|
||||||
def test_local_user_fallback_login(self):
|
def test_local_user_fallback_login_legacy(self):
|
||||||
|
self.local_user_fallback_login_test_body()
|
||||||
|
|
||||||
|
def local_user_fallback_login_test_body(self):
|
||||||
"""rejected login should fall back to local db"""
|
"""rejected login should fall back to local db"""
|
||||||
self.register_user("localuser", "localpass")
|
self.register_user("localuser", "localpass")
|
||||||
|
|
||||||
|
@ -186,8 +257,11 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
|
||||||
self.assertEqual(channel.code, 200, channel.result)
|
self.assertEqual(channel.code, 200, channel.result)
|
||||||
self.assertEqual("@localuser:test", channel.json_body["user_id"])
|
self.assertEqual("@localuser:test", channel.json_body["user_id"])
|
||||||
|
|
||||||
@override_config(providers_config(PasswordOnlyAuthProvider))
|
@override_config(legacy_providers_config(LegacyPasswordOnlyAuthProvider))
|
||||||
def test_local_user_fallback_ui_auth(self):
|
def test_local_user_fallback_ui_auth_legacy(self):
|
||||||
|
self.local_user_fallback_ui_auth_test_body()
|
||||||
|
|
||||||
|
def local_user_fallback_ui_auth_test_body(self):
|
||||||
"""rejected login should fall back to local db"""
|
"""rejected login should fall back to local db"""
|
||||||
self.register_user("localuser", "localpass")
|
self.register_user("localuser", "localpass")
|
||||||
|
|
||||||
|
@ -223,11 +297,14 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
|
||||||
|
|
||||||
@override_config(
|
@override_config(
|
||||||
{
|
{
|
||||||
**providers_config(PasswordOnlyAuthProvider),
|
**legacy_providers_config(LegacyPasswordOnlyAuthProvider),
|
||||||
"password_config": {"localdb_enabled": False},
|
"password_config": {"localdb_enabled": False},
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
def test_no_local_user_fallback_login(self):
|
def test_no_local_user_fallback_login_legacy(self):
|
||||||
|
self.no_local_user_fallback_login_test_body()
|
||||||
|
|
||||||
|
def no_local_user_fallback_login_test_body(self):
|
||||||
"""localdb_enabled can block login with the local password"""
|
"""localdb_enabled can block login with the local password"""
|
||||||
self.register_user("localuser", "localpass")
|
self.register_user("localuser", "localpass")
|
||||||
|
|
||||||
|
@ -242,11 +319,14 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
|
||||||
|
|
||||||
@override_config(
|
@override_config(
|
||||||
{
|
{
|
||||||
**providers_config(PasswordOnlyAuthProvider),
|
**legacy_providers_config(LegacyPasswordOnlyAuthProvider),
|
||||||
"password_config": {"localdb_enabled": False},
|
"password_config": {"localdb_enabled": False},
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
def test_no_local_user_fallback_ui_auth(self):
|
def test_no_local_user_fallback_ui_auth_legacy(self):
|
||||||
|
self.no_local_user_fallback_ui_auth_test_body()
|
||||||
|
|
||||||
|
def no_local_user_fallback_ui_auth_test_body(self):
|
||||||
"""localdb_enabled can block ui auth with the local password"""
|
"""localdb_enabled can block ui auth with the local password"""
|
||||||
self.register_user("localuser", "localpass")
|
self.register_user("localuser", "localpass")
|
||||||
|
|
||||||
|
@ -280,11 +360,14 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
|
||||||
|
|
||||||
@override_config(
|
@override_config(
|
||||||
{
|
{
|
||||||
**providers_config(PasswordOnlyAuthProvider),
|
**legacy_providers_config(LegacyPasswordOnlyAuthProvider),
|
||||||
"password_config": {"enabled": False},
|
"password_config": {"enabled": False},
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
def test_password_auth_disabled(self):
|
def test_password_auth_disabled_legacy(self):
|
||||||
|
self.password_auth_disabled_test_body()
|
||||||
|
|
||||||
|
def password_auth_disabled_test_body(self):
|
||||||
"""password auth doesn't work if it's disabled across the board"""
|
"""password auth doesn't work if it's disabled across the board"""
|
||||||
# login flows should be empty
|
# login flows should be empty
|
||||||
flows = self._get_login_flows()
|
flows = self._get_login_flows()
|
||||||
|
@ -295,8 +378,15 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
|
||||||
self.assertEqual(channel.code, 400, channel.result)
|
self.assertEqual(channel.code, 400, channel.result)
|
||||||
mock_password_provider.check_password.assert_not_called()
|
mock_password_provider.check_password.assert_not_called()
|
||||||
|
|
||||||
|
@override_config(legacy_providers_config(LegacyCustomAuthProvider))
|
||||||
|
def test_custom_auth_provider_login_legacy(self):
|
||||||
|
self.custom_auth_provider_login_test_body()
|
||||||
|
|
||||||
@override_config(providers_config(CustomAuthProvider))
|
@override_config(providers_config(CustomAuthProvider))
|
||||||
def test_custom_auth_provider_login(self):
|
def test_custom_auth_provider_login(self):
|
||||||
|
self.custom_auth_provider_login_test_body()
|
||||||
|
|
||||||
|
def custom_auth_provider_login_test_body(self):
|
||||||
# login flows should have the custom flow and m.login.password, since we
|
# login flows should have the custom flow and m.login.password, since we
|
||||||
# haven't disabled local password lookup.
|
# haven't disabled local password lookup.
|
||||||
# (password must come first, because reasons)
|
# (password must come first, because reasons)
|
||||||
|
@ -312,7 +402,9 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
|
||||||
self.assertEqual(channel.code, 400, channel.result)
|
self.assertEqual(channel.code, 400, channel.result)
|
||||||
mock_password_provider.check_auth.assert_not_called()
|
mock_password_provider.check_auth.assert_not_called()
|
||||||
|
|
||||||
mock_password_provider.check_auth.return_value = defer.succeed("@user:bz")
|
mock_password_provider.check_auth.return_value = defer.succeed(
|
||||||
|
("@user:bz", None)
|
||||||
|
)
|
||||||
channel = self._send_login("test.login_type", "u", test_field="y")
|
channel = self._send_login("test.login_type", "u", test_field="y")
|
||||||
self.assertEqual(channel.code, 200, channel.result)
|
self.assertEqual(channel.code, 200, channel.result)
|
||||||
self.assertEqual("@user:bz", channel.json_body["user_id"])
|
self.assertEqual("@user:bz", channel.json_body["user_id"])
|
||||||
|
@ -325,7 +417,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
|
||||||
# in these cases, but at least we can guard against the API changing
|
# in these cases, but at least we can guard against the API changing
|
||||||
# unexpectedly
|
# unexpectedly
|
||||||
mock_password_provider.check_auth.return_value = defer.succeed(
|
mock_password_provider.check_auth.return_value = defer.succeed(
|
||||||
"@ MALFORMED! :bz"
|
("@ MALFORMED! :bz", None)
|
||||||
)
|
)
|
||||||
channel = self._send_login("test.login_type", " USER🙂NAME ", test_field=" abc ")
|
channel = self._send_login("test.login_type", " USER🙂NAME ", test_field=" abc ")
|
||||||
self.assertEqual(channel.code, 200, channel.result)
|
self.assertEqual(channel.code, 200, channel.result)
|
||||||
|
@ -334,8 +426,15 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
|
||||||
" USER🙂NAME ", "test.login_type", {"test_field": " abc "}
|
" USER🙂NAME ", "test.login_type", {"test_field": " abc "}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@override_config(legacy_providers_config(LegacyCustomAuthProvider))
|
||||||
|
def test_custom_auth_provider_ui_auth_legacy(self):
|
||||||
|
self.custom_auth_provider_ui_auth_test_body()
|
||||||
|
|
||||||
@override_config(providers_config(CustomAuthProvider))
|
@override_config(providers_config(CustomAuthProvider))
|
||||||
def test_custom_auth_provider_ui_auth(self):
|
def test_custom_auth_provider_ui_auth(self):
|
||||||
|
self.custom_auth_provider_ui_auth_test_body()
|
||||||
|
|
||||||
|
def custom_auth_provider_ui_auth_test_body(self):
|
||||||
# register the user and log in twice, to get two devices
|
# register the user and log in twice, to get two devices
|
||||||
self.register_user("localuser", "localpass")
|
self.register_user("localuser", "localpass")
|
||||||
tok1 = self.login("localuser", "localpass")
|
tok1 = self.login("localuser", "localpass")
|
||||||
|
@ -367,7 +466,9 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
|
||||||
mock_password_provider.reset_mock()
|
mock_password_provider.reset_mock()
|
||||||
|
|
||||||
# right params, but authing as the wrong user
|
# right params, but authing as the wrong user
|
||||||
mock_password_provider.check_auth.return_value = defer.succeed("@user:bz")
|
mock_password_provider.check_auth.return_value = defer.succeed(
|
||||||
|
("@user:bz", None)
|
||||||
|
)
|
||||||
body["auth"]["test_field"] = "foo"
|
body["auth"]["test_field"] = "foo"
|
||||||
channel = self._delete_device(tok1, "dev2", body)
|
channel = self._delete_device(tok1, "dev2", body)
|
||||||
self.assertEqual(channel.code, 403)
|
self.assertEqual(channel.code, 403)
|
||||||
|
@ -379,7 +480,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
|
||||||
|
|
||||||
# and finally, succeed
|
# and finally, succeed
|
||||||
mock_password_provider.check_auth.return_value = defer.succeed(
|
mock_password_provider.check_auth.return_value = defer.succeed(
|
||||||
"@localuser:test"
|
("@localuser:test", None)
|
||||||
)
|
)
|
||||||
channel = self._delete_device(tok1, "dev2", body)
|
channel = self._delete_device(tok1, "dev2", body)
|
||||||
self.assertEqual(channel.code, 200)
|
self.assertEqual(channel.code, 200)
|
||||||
|
@ -387,8 +488,15 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
|
||||||
"localuser", "test.login_type", {"test_field": "foo"}
|
"localuser", "test.login_type", {"test_field": "foo"}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@override_config(legacy_providers_config(LegacyCustomAuthProvider))
|
||||||
|
def test_custom_auth_provider_callback_legacy(self):
|
||||||
|
self.custom_auth_provider_callback_test_body()
|
||||||
|
|
||||||
@override_config(providers_config(CustomAuthProvider))
|
@override_config(providers_config(CustomAuthProvider))
|
||||||
def test_custom_auth_provider_callback(self):
|
def test_custom_auth_provider_callback(self):
|
||||||
|
self.custom_auth_provider_callback_test_body()
|
||||||
|
|
||||||
|
def custom_auth_provider_callback_test_body(self):
|
||||||
callback = Mock(return_value=defer.succeed(None))
|
callback = Mock(return_value=defer.succeed(None))
|
||||||
|
|
||||||
mock_password_provider.check_auth.return_value = defer.succeed(
|
mock_password_provider.check_auth.return_value = defer.succeed(
|
||||||
|
@ -410,10 +518,22 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
|
||||||
for p in ["user_id", "access_token", "device_id", "home_server"]:
|
for p in ["user_id", "access_token", "device_id", "home_server"]:
|
||||||
self.assertIn(p, call_args[0])
|
self.assertIn(p, call_args[0])
|
||||||
|
|
||||||
|
@override_config(
|
||||||
|
{
|
||||||
|
**legacy_providers_config(LegacyCustomAuthProvider),
|
||||||
|
"password_config": {"enabled": False},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
def test_custom_auth_password_disabled_legacy(self):
|
||||||
|
self.custom_auth_password_disabled_test_body()
|
||||||
|
|
||||||
@override_config(
|
@override_config(
|
||||||
{**providers_config(CustomAuthProvider), "password_config": {"enabled": False}}
|
{**providers_config(CustomAuthProvider), "password_config": {"enabled": False}}
|
||||||
)
|
)
|
||||||
def test_custom_auth_password_disabled(self):
|
def test_custom_auth_password_disabled(self):
|
||||||
|
self.custom_auth_password_disabled_test_body()
|
||||||
|
|
||||||
|
def custom_auth_password_disabled_test_body(self):
|
||||||
"""Test login with a custom auth provider where password login is disabled"""
|
"""Test login with a custom auth provider where password login is disabled"""
|
||||||
self.register_user("localuser", "localpass")
|
self.register_user("localuser", "localpass")
|
||||||
|
|
||||||
|
@ -425,6 +545,15 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
|
||||||
self.assertEqual(channel.code, 400, channel.result)
|
self.assertEqual(channel.code, 400, channel.result)
|
||||||
mock_password_provider.check_auth.assert_not_called()
|
mock_password_provider.check_auth.assert_not_called()
|
||||||
|
|
||||||
|
@override_config(
|
||||||
|
{
|
||||||
|
**legacy_providers_config(LegacyCustomAuthProvider),
|
||||||
|
"password_config": {"enabled": False, "localdb_enabled": False},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
def test_custom_auth_password_disabled_localdb_enabled_legacy(self):
|
||||||
|
self.custom_auth_password_disabled_localdb_enabled_test_body()
|
||||||
|
|
||||||
@override_config(
|
@override_config(
|
||||||
{
|
{
|
||||||
**providers_config(CustomAuthProvider),
|
**providers_config(CustomAuthProvider),
|
||||||
|
@ -432,6 +561,9 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
def test_custom_auth_password_disabled_localdb_enabled(self):
|
def test_custom_auth_password_disabled_localdb_enabled(self):
|
||||||
|
self.custom_auth_password_disabled_localdb_enabled_test_body()
|
||||||
|
|
||||||
|
def custom_auth_password_disabled_localdb_enabled_test_body(self):
|
||||||
"""Check the localdb_enabled == enabled == False
|
"""Check the localdb_enabled == enabled == False
|
||||||
|
|
||||||
Regression test for https://github.com/matrix-org/synapse/issues/8914: check
|
Regression test for https://github.com/matrix-org/synapse/issues/8914: check
|
||||||
|
@ -448,6 +580,15 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
|
||||||
self.assertEqual(channel.code, 400, channel.result)
|
self.assertEqual(channel.code, 400, channel.result)
|
||||||
mock_password_provider.check_auth.assert_not_called()
|
mock_password_provider.check_auth.assert_not_called()
|
||||||
|
|
||||||
|
@override_config(
|
||||||
|
{
|
||||||
|
**legacy_providers_config(LegacyPasswordCustomAuthProvider),
|
||||||
|
"password_config": {"enabled": False},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
def test_password_custom_auth_password_disabled_login_legacy(self):
|
||||||
|
self.password_custom_auth_password_disabled_login_test_body()
|
||||||
|
|
||||||
@override_config(
|
@override_config(
|
||||||
{
|
{
|
||||||
**providers_config(PasswordCustomAuthProvider),
|
**providers_config(PasswordCustomAuthProvider),
|
||||||
|
@ -455,6 +596,9 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
def test_password_custom_auth_password_disabled_login(self):
|
def test_password_custom_auth_password_disabled_login(self):
|
||||||
|
self.password_custom_auth_password_disabled_login_test_body()
|
||||||
|
|
||||||
|
def password_custom_auth_password_disabled_login_test_body(self):
|
||||||
"""log in with a custom auth provider which implements password, but password
|
"""log in with a custom auth provider which implements password, but password
|
||||||
login is disabled"""
|
login is disabled"""
|
||||||
self.register_user("localuser", "localpass")
|
self.register_user("localuser", "localpass")
|
||||||
|
@ -466,6 +610,16 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
|
||||||
channel = self._send_password_login("localuser", "localpass")
|
channel = self._send_password_login("localuser", "localpass")
|
||||||
self.assertEqual(channel.code, 400, channel.result)
|
self.assertEqual(channel.code, 400, channel.result)
|
||||||
mock_password_provider.check_auth.assert_not_called()
|
mock_password_provider.check_auth.assert_not_called()
|
||||||
|
mock_password_provider.check_password.assert_not_called()
|
||||||
|
|
||||||
|
@override_config(
|
||||||
|
{
|
||||||
|
**legacy_providers_config(LegacyPasswordCustomAuthProvider),
|
||||||
|
"password_config": {"enabled": False},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
def test_password_custom_auth_password_disabled_ui_auth_legacy(self):
|
||||||
|
self.password_custom_auth_password_disabled_ui_auth_test_body()
|
||||||
|
|
||||||
@override_config(
|
@override_config(
|
||||||
{
|
{
|
||||||
|
@ -474,12 +628,15 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
def test_password_custom_auth_password_disabled_ui_auth(self):
|
def test_password_custom_auth_password_disabled_ui_auth(self):
|
||||||
|
self.password_custom_auth_password_disabled_ui_auth_test_body()
|
||||||
|
|
||||||
|
def password_custom_auth_password_disabled_ui_auth_test_body(self):
|
||||||
"""UI Auth with a custom auth provider which implements password, but password
|
"""UI Auth with a custom auth provider which implements password, but password
|
||||||
login is disabled"""
|
login is disabled"""
|
||||||
# register the user and log in twice via the test login type to get two devices,
|
# register the user and log in twice via the test login type to get two devices,
|
||||||
self.register_user("localuser", "localpass")
|
self.register_user("localuser", "localpass")
|
||||||
mock_password_provider.check_auth.return_value = defer.succeed(
|
mock_password_provider.check_auth.return_value = defer.succeed(
|
||||||
"@localuser:test"
|
("@localuser:test", None)
|
||||||
)
|
)
|
||||||
channel = self._send_login("test.login_type", "localuser", test_field="")
|
channel = self._send_login("test.login_type", "localuser", test_field="")
|
||||||
self.assertEqual(channel.code, 200, channel.result)
|
self.assertEqual(channel.code, 200, channel.result)
|
||||||
|
@ -516,6 +673,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
|
||||||
"Password login has been disabled.", channel.json_body["error"]
|
"Password login has been disabled.", channel.json_body["error"]
|
||||||
)
|
)
|
||||||
mock_password_provider.check_auth.assert_not_called()
|
mock_password_provider.check_auth.assert_not_called()
|
||||||
|
mock_password_provider.check_password.assert_not_called()
|
||||||
mock_password_provider.reset_mock()
|
mock_password_provider.reset_mock()
|
||||||
|
|
||||||
# successful auth
|
# successful auth
|
||||||
|
@ -526,6 +684,16 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
|
||||||
mock_password_provider.check_auth.assert_called_once_with(
|
mock_password_provider.check_auth.assert_called_once_with(
|
||||||
"localuser", "test.login_type", {"test_field": "x"}
|
"localuser", "test.login_type", {"test_field": "x"}
|
||||||
)
|
)
|
||||||
|
mock_password_provider.check_password.assert_not_called()
|
||||||
|
|
||||||
|
@override_config(
|
||||||
|
{
|
||||||
|
**legacy_providers_config(LegacyCustomAuthProvider),
|
||||||
|
"password_config": {"localdb_enabled": False},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
def test_custom_auth_no_local_user_fallback_legacy(self):
|
||||||
|
self.custom_auth_no_local_user_fallback_test_body()
|
||||||
|
|
||||||
@override_config(
|
@override_config(
|
||||||
{
|
{
|
||||||
|
@ -534,6 +702,9 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
def test_custom_auth_no_local_user_fallback(self):
|
def test_custom_auth_no_local_user_fallback(self):
|
||||||
|
self.custom_auth_no_local_user_fallback_test_body()
|
||||||
|
|
||||||
|
def custom_auth_no_local_user_fallback_test_body(self):
|
||||||
"""Test login with a custom auth provider where the local db is disabled"""
|
"""Test login with a custom auth provider where the local db is disabled"""
|
||||||
self.register_user("localuser", "localpass")
|
self.register_user("localuser", "localpass")
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue