Consolidate logic to check for deactivated users. (#15634)

This moves the deactivated user check to the method which
all login types call.

Additionally updates the application service tests to be more
realistic by removing invalid tests and fixing server names.
madlittlemods/15657-export-synapse-version-as-metric
Patrick Cloke 2023-05-23 10:35:43 -04:00 committed by GitHub
parent 1df0221bda
commit 7c9b91790c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 55 additions and 67 deletions

1
changelog.d/15634.bugfix Normal file
View File

@ -0,0 +1 @@
Fix a long-standing bug where deactivated users were able to login in uncommon situations.

View File

@ -46,6 +46,9 @@ instead.
If the authentication is unsuccessful, the module must return `None`. If the authentication is unsuccessful, the module must return `None`.
Note that the user is not automatically registered, the `register_user(..)` method of
the [module API](writing_a_module.html) can be used to lazily create users.
If multiple modules register an auth checker for the same login type but with different If multiple modules register an auth checker for the same login type but with different
fields, Synapse will refuse to start. fields, Synapse will refuse to start.

View File

@ -86,6 +86,7 @@ class ApplicationService:
url.rstrip("/") if isinstance(url, str) else None url.rstrip("/") if isinstance(url, str) else None
) # url must not end with a slash ) # url must not end with a slash
self.hs_token = hs_token self.hs_token = hs_token
# The full Matrix ID for this application service's sender.
self.sender = sender self.sender = sender
self.namespaces = self._check_namespaces(namespaces) self.namespaces = self._check_namespaces(namespaces)
self.id = id self.id = id
@ -212,7 +213,7 @@ class ApplicationService:
True if the application service is interested in the user, False if not. True if the application service is interested in the user, False if not.
""" """
return ( return (
# User is the appservice's sender_localpart user # User is the appservice's configured sender_localpart user
user_id == self.sender user_id == self.sender
# User is in the appservice's user namespace # User is in the appservice's user namespace
or self.is_user_in_namespace(user_id) or self.is_user_in_namespace(user_id)

View File

@ -52,7 +52,6 @@ from synapse.api.errors import (
NotFoundError, NotFoundError,
StoreError, StoreError,
SynapseError, SynapseError,
UserDeactivatedError,
) )
from synapse.api.ratelimiting import Ratelimiter from synapse.api.ratelimiting import Ratelimiter
from synapse.handlers.ui_auth import ( from synapse.handlers.ui_auth import (
@ -1419,12 +1418,6 @@ class AuthHandler:
return None return None
(user_id, password_hash) = lookupres (user_id, password_hash) = lookupres
# If the password hash is None, the account has likely been deactivated
if not password_hash:
deactivated = await self.store.get_user_deactivated_status(user_id)
if deactivated:
raise UserDeactivatedError("This account has been deactivated")
result = await self.validate_hash(password, password_hash) result = await self.validate_hash(password, password_hash)
if not result: if not result:
logger.warning("Failed password login for user %s", user_id) logger.warning("Failed password login for user %s", user_id)
@ -1749,8 +1742,11 @@ class AuthHandler:
registered. registered.
auth_provider_session_id: The session ID from the SSO IdP received during login. auth_provider_session_id: The session ID from the SSO IdP received during login.
""" """
# If the account has been deactivated, do not proceed with the login # If the account has been deactivated, do not proceed with the login.
# flow. #
# This gets checked again when the token is submitted but this lets us
# provide an HTML error page to the user (instead of issuing a token and
# having it error later).
deactivated = await self.store.get_user_deactivated_status(registered_user_id) deactivated = await self.store.get_user_deactivated_status(registered_user_id)
if deactivated: if deactivated:
respond_with_html(request, 403, self._sso_account_deactivated_template) respond_with_html(request, 403, self._sso_account_deactivated_template)

View File

@ -16,7 +16,7 @@ from typing import TYPE_CHECKING
from authlib.jose import JsonWebToken, JWTClaims from authlib.jose import JsonWebToken, JWTClaims
from authlib.jose.errors import BadSignatureError, InvalidClaimError, JoseError from authlib.jose.errors import BadSignatureError, InvalidClaimError, JoseError
from synapse.api.errors import Codes, LoginError, StoreError, UserDeactivatedError from synapse.api.errors import Codes, LoginError
from synapse.types import JsonDict, UserID from synapse.types import JsonDict, UserID
if TYPE_CHECKING: if TYPE_CHECKING:
@ -26,7 +26,6 @@ if TYPE_CHECKING:
class JwtHandler: class JwtHandler:
def __init__(self, hs: "HomeServer"): def __init__(self, hs: "HomeServer"):
self.hs = hs self.hs = hs
self._main_store = hs.get_datastores().main
self.jwt_secret = hs.config.jwt.jwt_secret self.jwt_secret = hs.config.jwt.jwt_secret
self.jwt_subject_claim = hs.config.jwt.jwt_subject_claim self.jwt_subject_claim = hs.config.jwt.jwt_subject_claim
@ -34,7 +33,7 @@ class JwtHandler:
self.jwt_issuer = hs.config.jwt.jwt_issuer self.jwt_issuer = hs.config.jwt.jwt_issuer
self.jwt_audiences = hs.config.jwt.jwt_audiences self.jwt_audiences = hs.config.jwt.jwt_audiences
async def validate_login(self, login_submission: JsonDict) -> str: def validate_login(self, login_submission: JsonDict) -> str:
""" """
Authenticates the user for the /login API Authenticates the user for the /login API
@ -103,16 +102,4 @@ class JwtHandler:
if user is None: if user is None:
raise LoginError(403, "Invalid JWT", errcode=Codes.FORBIDDEN) raise LoginError(403, "Invalid JWT", errcode=Codes.FORBIDDEN)
user_id = UserID(user, self.hs.hostname).to_string() return UserID(user, self.hs.hostname).to_string()
# If the account has been deactivated, do not proceed with the login
# flow.
try:
deactivated = await self._main_store.get_user_deactivated_status(user_id)
except StoreError:
# JWT lazily creates users, so they may not exist in the database yet.
deactivated = False
if deactivated:
raise UserDeactivatedError("This account has been deactivated")
return user_id

View File

@ -35,6 +35,7 @@ from synapse.api.errors import (
LoginError, LoginError,
NotApprovedError, NotApprovedError,
SynapseError, SynapseError,
UserDeactivatedError,
) )
from synapse.api.ratelimiting import Ratelimiter from synapse.api.ratelimiting import Ratelimiter
from synapse.api.urls import CLIENT_API_PREFIX from synapse.api.urls import CLIENT_API_PREFIX
@ -84,6 +85,7 @@ class LoginRestServlet(RestServlet):
def __init__(self, hs: "HomeServer"): def __init__(self, hs: "HomeServer"):
super().__init__() super().__init__()
self.hs = hs self.hs = hs
self._main_store = hs.get_datastores().main
# JWT configuration variables. # JWT configuration variables.
self.jwt_enabled = hs.config.jwt.jwt_enabled self.jwt_enabled = hs.config.jwt.jwt_enabled
@ -112,13 +114,13 @@ class LoginRestServlet(RestServlet):
self._well_known_builder = WellKnownBuilder(hs) self._well_known_builder = WellKnownBuilder(hs)
self._address_ratelimiter = Ratelimiter( self._address_ratelimiter = Ratelimiter(
store=hs.get_datastores().main, store=self._main_store,
clock=hs.get_clock(), clock=hs.get_clock(),
rate_hz=self.hs.config.ratelimiting.rc_login_address.per_second, rate_hz=self.hs.config.ratelimiting.rc_login_address.per_second,
burst_count=self.hs.config.ratelimiting.rc_login_address.burst_count, burst_count=self.hs.config.ratelimiting.rc_login_address.burst_count,
) )
self._account_ratelimiter = Ratelimiter( self._account_ratelimiter = Ratelimiter(
store=hs.get_datastores().main, store=self._main_store,
clock=hs.get_clock(), clock=hs.get_clock(),
rate_hz=self.hs.config.ratelimiting.rc_login_account.per_second, rate_hz=self.hs.config.ratelimiting.rc_login_account.per_second,
burst_count=self.hs.config.ratelimiting.rc_login_account.burst_count, burst_count=self.hs.config.ratelimiting.rc_login_account.burst_count,
@ -280,6 +282,9 @@ class LoginRestServlet(RestServlet):
login_submission, login_submission,
ratelimit=appservice.is_rate_limited(), ratelimit=appservice.is_rate_limited(),
should_issue_refresh_token=should_issue_refresh_token, should_issue_refresh_token=should_issue_refresh_token,
# The user represented by an appservice's configured sender_localpart
# is not actually created in Synapse.
should_check_deactivated=qualified_user_id != appservice.sender,
) )
async def _do_other_login( async def _do_other_login(
@ -326,6 +331,7 @@ class LoginRestServlet(RestServlet):
auth_provider_id: Optional[str] = None, auth_provider_id: Optional[str] = None,
should_issue_refresh_token: bool = False, should_issue_refresh_token: bool = False,
auth_provider_session_id: Optional[str] = None, auth_provider_session_id: Optional[str] = None,
should_check_deactivated: bool = True,
) -> LoginResponse: ) -> LoginResponse:
"""Called when we've successfully authed the user and now need to """Called when we've successfully authed the user and now need to
actually login them in (e.g. create devices). This gets called on actually login them in (e.g. create devices). This gets called on
@ -345,6 +351,11 @@ class LoginRestServlet(RestServlet):
should_issue_refresh_token: True if this login should issue should_issue_refresh_token: True if this login should issue
a refresh token alongside the access token. a refresh token alongside the access token.
auth_provider_session_id: The session ID got during login from the SSO IdP. auth_provider_session_id: The session ID got during login from the SSO IdP.
should_check_deactivated: True if the user should be checked for
deactivation status before logging in.
This exists purely for appservice's configured sender_localpart
which doesn't have an associated user in the database.
Returns: Returns:
Dictionary of account information after successful login. Dictionary of account information after successful login.
@ -364,6 +375,12 @@ class LoginRestServlet(RestServlet):
) )
user_id = canonical_uid user_id = canonical_uid
# If the account has been deactivated, do not proceed with the login.
if should_check_deactivated:
deactivated = await self._main_store.get_user_deactivated_status(user_id)
if deactivated:
raise UserDeactivatedError("This account has been deactivated")
device_id = login_submission.get("device_id") device_id = login_submission.get("device_id")
# If device_id is present, check that device_id is not longer than a reasonable 512 characters # If device_id is present, check that device_id is not longer than a reasonable 512 characters
@ -458,7 +475,7 @@ class LoginRestServlet(RestServlet):
Returns: Returns:
The body of the JSON response. The body of the JSON response.
""" """
user_id = await self.hs.get_jwt_handler().validate_login(login_submission) user_id = self.hs.get_jwt_handler().validate_login(login_submission)
return await self._complete_login( return await self._complete_login(
user_id, user_id,
login_submission, login_submission,

View File

@ -18,13 +18,17 @@ from http import HTTPStatus
from typing import Any, Dict, List, Optional, Type, Union from typing import Any, Dict, List, Optional, Type, Union
from unittest.mock import Mock from unittest.mock import Mock
from twisted.test.proto_helpers import MemoryReactor
import synapse import synapse
from synapse.api.constants import LoginType from synapse.api.constants import LoginType
from synapse.api.errors import Codes from synapse.api.errors import Codes
from synapse.handlers.account import AccountHandler from synapse.handlers.account import AccountHandler
from synapse.module_api import ModuleApi from synapse.module_api import ModuleApi
from synapse.rest.client import account, devices, login, logout, register from synapse.rest.client import account, devices, login, logout, register
from synapse.server import HomeServer
from synapse.types import JsonDict, UserID from synapse.types import JsonDict, UserID
from synapse.util import Clock
from tests import unittest from tests import unittest
from tests.server import FakeChannel from tests.server import FakeChannel
@ -162,10 +166,16 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
CALLBACK_USERNAME = "get_username_for_registration" CALLBACK_USERNAME = "get_username_for_registration"
CALLBACK_DISPLAYNAME = "get_displayname_for_registration" CALLBACK_DISPLAYNAME = "get_displayname_for_registration"
def setUp(self) -> None: def prepare(
self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer
) -> None:
# we use a global mock device, so make sure we are starting with a clean slate # we use a global mock device, so make sure we are starting with a clean slate
mock_password_provider.reset_mock() mock_password_provider.reset_mock()
super().setUp()
# The mock password provider doesn't register the users, so ensure they
# are registered first.
self.register_user("u", "not-the-tested-password")
self.register_user("user", "not-the-tested-password")
@override_config(legacy_providers_config(LegacyPasswordOnlyAuthProvider)) @override_config(legacy_providers_config(LegacyPasswordOnlyAuthProvider))
def test_password_only_auth_progiver_login_legacy(self) -> None: def test_password_only_auth_progiver_login_legacy(self) -> None:
@ -185,22 +195,12 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
mock_password_provider.reset_mock() mock_password_provider.reset_mock()
# login with mxid should work too # login with mxid should work too
channel = self._send_password_login("@u:bz", "p") channel = self._send_password_login("@u:test", "p")
self.assertEqual(channel.code, HTTPStatus.OK, channel.result) self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
self.assertEqual("@u:bz", channel.json_body["user_id"]) self.assertEqual("@u:test", channel.json_body["user_id"])
mock_password_provider.check_password.assert_called_once_with("@u:bz", "p") mock_password_provider.check_password.assert_called_once_with("@u:test", "p")
mock_password_provider.reset_mock() mock_password_provider.reset_mock()
# try a weird username / pass. Honestly it's unclear what we *expect* to happen
# in these cases, but at least we can guard against the API changing
# unexpectedly
channel = self._send_password_login(" USER🙂NAME ", " pASS\U0001F622word ")
self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
self.assertEqual("@ USER🙂NAME :test", channel.json_body["user_id"])
mock_password_provider.check_password.assert_called_once_with(
"@ USER🙂NAME :test", " pASS😢word "
)
@override_config(legacy_providers_config(LegacyPasswordOnlyAuthProvider)) @override_config(legacy_providers_config(LegacyPasswordOnlyAuthProvider))
def test_password_only_auth_provider_ui_auth_legacy(self) -> None: def test_password_only_auth_provider_ui_auth_legacy(self) -> None:
self.password_only_auth_provider_ui_auth_test_body() self.password_only_auth_provider_ui_auth_test_body()
@ -208,10 +208,6 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
def password_only_auth_provider_ui_auth_test_body(self) -> None: def password_only_auth_provider_ui_auth_test_body(self) -> None:
"""UI Auth should delegate correctly to the password provider""" """UI Auth should delegate correctly to the password provider"""
# create the user, otherwise access doesn't work
module_api = self.hs.get_module_api()
self.get_success(module_api.register_user("u"))
# log in twice, to get two devices # log in twice, to get two devices
mock_password_provider.check_password.return_value = make_awaitable(True) mock_password_provider.check_password.return_value = make_awaitable(True)
tok1 = self.login("u", "p") tok1 = self.login("u", "p")
@ -401,29 +397,16 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
mock_password_provider.check_auth.assert_not_called() mock_password_provider.check_auth.assert_not_called()
mock_password_provider.check_auth.return_value = make_awaitable( mock_password_provider.check_auth.return_value = make_awaitable(
("@user:bz", None) ("@user:test", None)
) )
channel = self._send_login("test.login_type", "u", test_field="y") channel = self._send_login("test.login_type", "u", test_field="y")
self.assertEqual(channel.code, HTTPStatus.OK, channel.result) self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
self.assertEqual("@user:bz", channel.json_body["user_id"]) self.assertEqual("@user:test", channel.json_body["user_id"])
mock_password_provider.check_auth.assert_called_once_with( mock_password_provider.check_auth.assert_called_once_with(
"u", "test.login_type", {"test_field": "y"} "u", "test.login_type", {"test_field": "y"}
) )
mock_password_provider.reset_mock() mock_password_provider.reset_mock()
# try a weird username. Again, it's unclear what we *expect* to happen
# in these cases, but at least we can guard against the API changing
# unexpectedly
mock_password_provider.check_auth.return_value = make_awaitable(
("@ MALFORMED! :bz", None)
)
channel = self._send_login("test.login_type", " USER🙂NAME ", test_field=" abc ")
self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
self.assertEqual("@ MALFORMED! :bz", channel.json_body["user_id"])
mock_password_provider.check_auth.assert_called_once_with(
" USER🙂NAME ", "test.login_type", {"test_field": " abc "}
)
@override_config(legacy_providers_config(LegacyCustomAuthProvider)) @override_config(legacy_providers_config(LegacyCustomAuthProvider))
def test_custom_auth_provider_ui_auth_legacy(self) -> None: def test_custom_auth_provider_ui_auth_legacy(self) -> None:
self.custom_auth_provider_ui_auth_test_body() self.custom_auth_provider_ui_auth_test_body()
@ -465,7 +448,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
# right params, but authing as the wrong user # right params, but authing as the wrong user
mock_password_provider.check_auth.return_value = make_awaitable( mock_password_provider.check_auth.return_value = make_awaitable(
("@user:bz", None) ("@user:test", None)
) )
body["auth"]["test_field"] = "foo" body["auth"]["test_field"] = "foo"
channel = self._delete_device(tok1, "dev2", body) channel = self._delete_device(tok1, "dev2", body)
@ -498,11 +481,11 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
callback = Mock(return_value=make_awaitable(None)) callback = Mock(return_value=make_awaitable(None))
mock_password_provider.check_auth.return_value = make_awaitable( mock_password_provider.check_auth.return_value = make_awaitable(
("@user:bz", callback) ("@user:test", callback)
) )
channel = self._send_login("test.login_type", "u", test_field="y") channel = self._send_login("test.login_type", "u", test_field="y")
self.assertEqual(channel.code, HTTPStatus.OK, channel.result) self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
self.assertEqual("@user:bz", channel.json_body["user_id"]) self.assertEqual("@user:test", channel.json_body["user_id"])
mock_password_provider.check_auth.assert_called_once_with( mock_password_provider.check_auth.assert_called_once_with(
"u", "test.login_type", {"test_field": "y"} "u", "test.login_type", {"test_field": "y"}
) )
@ -512,7 +495,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
call_args, call_kwargs = callback.call_args call_args, call_kwargs = callback.call_args
# should be one positional arg # should be one positional arg
self.assertEqual(len(call_args), 1) self.assertEqual(len(call_args), 1)
self.assertEqual(call_args[0]["user_id"], "@user:bz") self.assertEqual(call_args[0]["user_id"], "@user:test")
for p in ["user_id", "access_token", "device_id", "home_server"]: for p in ["user_id", "access_token", "device_id", "home_server"]:
self.assertIn(p, call_args[0]) self.assertIn(p, call_args[0])