UIA: offer only available auth flows

During user-interactive auth, do not offer password auth to users with no
password, nor SSO auth to users with no SSO.

Fixes #7559.
pull/8858/head
Richard van der Hoff 2020-12-01 00:15:36 +00:00
parent 76469898ee
commit 0bac276890
6 changed files with 278 additions and 33 deletions

View File

@ -193,9 +193,7 @@ class AuthHandler(BaseHandler):
self.hs = hs # FIXME better possibility to access registrationHandler later? self.hs = hs # FIXME better possibility to access registrationHandler later?
self.macaroon_gen = hs.get_macaroon_generator() self.macaroon_gen = hs.get_macaroon_generator()
self._password_enabled = hs.config.password_enabled self._password_enabled = hs.config.password_enabled
self._sso_enabled = ( self._password_localdb_enabled = hs.config.password_localdb_enabled
hs.config.cas_enabled or hs.config.saml2_enabled or hs.config.oidc_enabled
)
# we keep this as a list despite the O(N^2) implication so that we can # we keep this as a list despite the O(N^2) implication so that we can
# keep PASSWORD first and avoid confusing clients which pick the first # keep PASSWORD first and avoid confusing clients which pick the first
@ -205,7 +203,7 @@ class AuthHandler(BaseHandler):
# start out by assuming PASSWORD is enabled; we will remove it later if not. # start out by assuming PASSWORD is enabled; we will remove it later if not.
login_types = [] login_types = []
if hs.config.password_localdb_enabled: if self._password_localdb_enabled:
login_types.append(LoginType.PASSWORD) login_types.append(LoginType.PASSWORD)
for provider in self.password_providers: for provider in self.password_providers:
@ -219,14 +217,6 @@ class AuthHandler(BaseHandler):
self._supported_login_types = login_types self._supported_login_types = login_types
# Login types and UI Auth types have a heavy overlap, but are not
# necessarily identical. Login types have SSO (and other login types)
# added in the rest layer, see synapse.rest.client.v1.login.LoginRestServerlet.on_GET.
ui_auth_types = login_types.copy()
if self._sso_enabled:
ui_auth_types.append(LoginType.SSO)
self._supported_ui_auth_types = ui_auth_types
# Ratelimiter for failed auth during UIA. Uses same ratelimit config # Ratelimiter for failed auth during UIA. Uses same ratelimit config
# as per `rc_login.failed_attempts`. # as per `rc_login.failed_attempts`.
self._failed_uia_attempts_ratelimiter = Ratelimiter( self._failed_uia_attempts_ratelimiter = Ratelimiter(
@ -339,7 +329,10 @@ class AuthHandler(BaseHandler):
self._failed_uia_attempts_ratelimiter.ratelimit(user_id, update=False) self._failed_uia_attempts_ratelimiter.ratelimit(user_id, update=False)
# build a list of supported flows # build a list of supported flows
flows = [[login_type] for login_type in self._supported_ui_auth_types] supported_ui_auth_types = await self._get_available_ui_auth_types(
requester.user
)
flows = [[login_type] for login_type in supported_ui_auth_types]
try: try:
result, params, session_id = await self.check_ui_auth( result, params, session_id = await self.check_ui_auth(
@ -351,7 +344,7 @@ class AuthHandler(BaseHandler):
raise raise
# find the completed login type # find the completed login type
for login_type in self._supported_ui_auth_types: for login_type in supported_ui_auth_types:
if login_type not in result: if login_type not in result:
continue continue
@ -367,6 +360,41 @@ class AuthHandler(BaseHandler):
return params, session_id return params, session_id
async def _get_available_ui_auth_types(self, user: UserID) -> Iterable[str]:
"""Get a list of the authentication types this user can use
"""
ui_auth_types = set()
# if the HS supports password auth, and the user has a non-null password, we
# support password auth
if self._password_localdb_enabled and self._password_enabled:
lookupres = await self._find_user_id_and_pwd_hash(user.to_string())
if lookupres:
_, password_hash = lookupres
if password_hash:
ui_auth_types.add(LoginType.PASSWORD)
# also allow auth from password providers
for provider in self.password_providers:
for t in provider.get_supported_login_types().keys():
if t == LoginType.PASSWORD and not self._password_enabled:
continue
ui_auth_types.add(t)
# if sso is enabled, allow the user to log in via SSO iff they have a mapping
# from sso to mxid.
if self.hs.config.saml2.saml2_enabled or self.hs.config.oidc.oidc_enabled:
if await self.store.get_external_ids_by_user(user.to_string()):
ui_auth_types.add(LoginType.SSO)
# Our CAS impl does not (yet) correctly register users in user_external_ids,
# so always offer that if it's available.
if self.hs.config.cas.cas_enabled:
ui_auth_types.add(LoginType.SSO)
return ui_auth_types
def get_enabled_auth_types(self): def get_enabled_auth_types(self):
"""Return the enabled user-interactive authentication types """Return the enabled user-interactive authentication types
@ -1029,7 +1057,7 @@ class AuthHandler(BaseHandler):
if result: if result:
return result return result
if login_type == LoginType.PASSWORD and self.hs.config.password_localdb_enabled: if login_type == LoginType.PASSWORD and self._password_localdb_enabled:
known_login_type = True known_login_type = True
# we've already checked that there is a (valid) password field # we've already checked that there is a (valid) password field

View File

@ -463,6 +463,23 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
desc="get_user_by_external_id", desc="get_user_by_external_id",
) )
async def get_external_ids_by_user(self, mxid: str) -> List[Tuple[str, str]]:
"""Look up external ids for the given user
Args:
mxid: the MXID to be looked up
Returns:
Tuples of (auth_provider, external_id)
"""
res = await self.db_pool.simple_select_list(
table="user_external_ids",
keyvalues={"user_id": mxid},
retcols=("auth_provider", "external_id"),
desc="get_external_ids_by_user",
)
return [(r["auth_provider"], r["external_id"]) for r in res]
async def count_all_users(self): async def count_all_users(self):
"""Counts all users registered on the homeserver.""" """Counts all users registered on the homeserver."""
@ -963,6 +980,14 @@ class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
"users_set_deactivated_flag", self._background_update_set_deactivated_flag "users_set_deactivated_flag", self._background_update_set_deactivated_flag
) )
self.db_pool.updates.register_background_index_update(
"user_external_ids_user_id_idx",
index_name="user_external_ids_user_id_idx",
table="user_external_ids",
columns=["user_id"],
unique=False,
)
async def _background_update_set_deactivated_flag(self, progress, batch_size): async def _background_update_set_deactivated_flag(self, progress, batch_size):
"""Retrieves a list of all deactivated users and sets the 'deactivated' flag to 1 """Retrieves a list of all deactivated users and sets the 'deactivated' flag to 1
for each of them. for each of them.

View File

@ -0,0 +1,17 @@
/* Copyright 2020 The Matrix.org Foundation C.I.C
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
INSERT INTO background_updates (ordering, update_name, progress_json) VALUES
(5825, 'user_external_ids_user_id_idx', '{}');

View File

@ -2,7 +2,7 @@
# Copyright 2014-2016 OpenMarket Ltd # Copyright 2014-2016 OpenMarket Ltd
# Copyright 2017 Vector Creations Ltd # Copyright 2017 Vector Creations Ltd
# Copyright 2018-2019 New Vector Ltd # Copyright 2018-2019 New Vector Ltd
# Copyright 2019 The Matrix.org Foundation C.I.C. # Copyright 2019-2020 The Matrix.org Foundation C.I.C.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -17,17 +17,23 @@
# limitations under the License. # limitations under the License.
import json import json
import re
import time import time
import urllib.parse
from typing import Any, Dict, Optional from typing import Any, Dict, Optional
from mock import patch
import attr import attr
from twisted.web.resource import Resource from twisted.web.resource import Resource
from twisted.web.server import Site from twisted.web.server import Site
from synapse.api.constants import Membership from synapse.api.constants import Membership
from synapse.types import JsonDict
from tests.server import FakeSite, make_request from tests.server import FakeSite, make_request
from tests.test_utils import FakeResponse
@attr.s @attr.s
@ -344,3 +350,111 @@ class RestHelper:
) )
return channel.json_body return channel.json_body
def login_via_oidc(self, remote_user_id: str) -> JsonDict:
"""Log in (as a new user) via OIDC
Returns the result of the final token login.
Requires that "oidc_config" in the homeserver config be set appropriately
(TEST_OIDC_CONFIG is a suitable example) - and by implication, needs a
"public_base_url".
Also requires the login servlet and the OIDC callback resource to be mounted at
the normal places.
"""
client_redirect_url = "https://x"
# first hit the redirect url (which will issue a cookie and state)
_, channel = make_request(
self.hs.get_reactor(),
self.site,
"GET",
"/login/sso/redirect?redirectUrl=" + client_redirect_url,
)
# that will redirect to the OIDC IdP, but we skip that and go straight
# back to synapse's OIDC callback resource. However, we do need the "state"
# param that synapse passes to the IdP via query params, and the cookie that
# synapse passes to the client.
assert channel.code == 302
oauth_uri = channel.headers.getRawHeaders("Location")[0]
params = urllib.parse.parse_qs(urllib.parse.urlparse(oauth_uri).query)
redirect_uri = "%s?%s" % (
urllib.parse.urlparse(params["redirect_uri"][0]).path,
urllib.parse.urlencode({"state": params["state"][0], "code": "TEST_CODE"}),
)
cookies = {}
for h in channel.headers.getRawHeaders("Set-Cookie"):
parts = h.split(";")
k, v = parts[0].split("=", maxsplit=1)
cookies[k] = v
# before we hit the callback uri, stub out some methods in the http client so
# that we don't have to handle full HTTPS requests.
# (expected url, json response) pairs, in the order we expect them.
expected_requests = [
# first we get a hit to the token endpoint, which we tell to return
# a dummy OIDC access token
("https://issuer.test/token", {"access_token": "TEST"}),
# and then one to the user_info endpoint, which returns our remote user id.
("https://issuer.test/userinfo", {"sub": remote_user_id}),
]
async def mock_req(method: str, uri: str, data=None, headers=None):
(expected_uri, resp_obj) = expected_requests.pop(0)
assert uri == expected_uri
resp = FakeResponse(
code=200, phrase=b"OK", body=json.dumps(resp_obj).encode("utf-8"),
)
return resp
with patch.object(self.hs.get_proxied_http_client(), "request", mock_req):
# now hit the callback URI with the right params and a made-up code
_, channel = make_request(
self.hs.get_reactor(),
self.site,
"GET",
redirect_uri,
custom_headers=[
("Cookie", "%s=%s" % (k, v)) for (k, v) in cookies.items()
],
)
# expect a confirmation page
assert channel.code == 200
# fish the matrix login token out of the body of the confirmation page
m = re.search(
'a href="%s.*loginToken=([^"]*)"' % (client_redirect_url,),
channel.result["body"].decode("utf-8"),
)
assert m
login_token = m.group(1)
# finally, submit the matrix login token to the login API, which gives us our
# matrix access token and device id.
_, channel = make_request(
self.hs.get_reactor(),
self.site,
"POST",
"/login",
content={"type": "m.login.token", "token": login_token},
)
assert channel.code == 200
return channel.json_body
# an 'oidc_config' suitable for login_with_oidc.
TEST_OIDC_CONFIG = {
"enabled": True,
"discover": False,
"issuer": "https://issuer.test",
"client_id": "test-client-id",
"client_secret": "test-client-secret",
"scopes": ["profile"],
"authorization_endpoint": "https://z",
"token_endpoint": "https://issuer.test/token",
"userinfo_endpoint": "https://issuer.test/userinfo",
"user_mapping_provider": {"config": {"localpart_template": "{{ user.sub }}"}},
}

View File

@ -12,6 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from typing import List, Union from typing import List, Union
from twisted.internet.defer import succeed from twisted.internet.defer import succeed
@ -22,9 +23,11 @@ from synapse.handlers.ui_auth.checkers import UserInteractiveAuthChecker
from synapse.http.site import SynapseRequest from synapse.http.site import SynapseRequest
from synapse.rest.client.v1 import login from synapse.rest.client.v1 import login
from synapse.rest.client.v2_alpha import auth, devices, register from synapse.rest.client.v2_alpha import auth, devices, register
from synapse.types import JsonDict from synapse.rest.oidc import OIDCResource
from synapse.types import JsonDict, UserID
from tests import unittest from tests import unittest
from tests.rest.client.v1.utils import TEST_OIDC_CONFIG
from tests.server import FakeChannel from tests.server import FakeChannel
@ -156,27 +159,45 @@ class UIAuthTests(unittest.HomeserverTestCase):
register.register_servlets, register.register_servlets,
] ]
def default_config(self):
config = super().default_config()
# we enable OIDC as a way of testing SSO flows
oidc_config = {}
oidc_config.update(TEST_OIDC_CONFIG)
oidc_config["allow_existing_users"] = True
config["oidc_config"] = oidc_config
config["public_baseurl"] = "https://synapse.test"
return config
def create_resource_dict(self):
resource_dict = super().create_resource_dict()
# mount the OIDC resource at /_synapse/oidc
resource_dict["/_synapse/oidc"] = OIDCResource(self.hs)
return resource_dict
def prepare(self, reactor, clock, hs): def prepare(self, reactor, clock, hs):
self.user_pass = "pass" self.user_pass = "pass"
self.user = self.register_user("test", self.user_pass) self.user = self.register_user("test", self.user_pass)
self.user_tok = self.login("test", self.user_pass) self.user_tok = self.login("test", self.user_pass)
def get_device_ids(self) -> List[str]: def get_device_ids(self, access_token: str) -> List[str]:
# Get the list of devices so one can be deleted. # Get the list of devices so one can be deleted.
request, channel = self.make_request( _, channel = self.make_request("GET", "devices", access_token=access_token,)
"GET", "devices", access_token=self.user_tok, self.assertEqual(channel.code, 200)
) # type: SynapseRequest, FakeChannel
# Get the ID of the device.
self.assertEqual(request.code, 200)
return [d["device_id"] for d in channel.json_body["devices"]] return [d["device_id"] for d in channel.json_body["devices"]]
def delete_device( def delete_device(
self, device: str, expected_response: int, body: Union[bytes, JsonDict] = b"" self,
access_token: str,
device: str,
expected_response: int,
body: Union[bytes, JsonDict] = b"",
) -> FakeChannel: ) -> FakeChannel:
"""Delete an individual device.""" """Delete an individual device."""
request, channel = self.make_request( request, channel = self.make_request(
"DELETE", "devices/" + device, body, access_token=self.user_tok "DELETE", "devices/" + device, body, access_token=access_token,
) # type: SynapseRequest, FakeChannel ) # type: SynapseRequest, FakeChannel
# Ensure the response is sane. # Ensure the response is sane.
@ -201,11 +222,11 @@ class UIAuthTests(unittest.HomeserverTestCase):
""" """
Test user interactive authentication outside of registration. Test user interactive authentication outside of registration.
""" """
device_id = self.get_device_ids()[0] device_id = self.get_device_ids(self.user_tok)[0]
# Attempt to delete this device. # Attempt to delete this device.
# Returns a 401 as per the spec # Returns a 401 as per the spec
channel = self.delete_device(device_id, 401) channel = self.delete_device(self.user_tok, device_id, 401)
# Grab the session # Grab the session
session = channel.json_body["session"] session = channel.json_body["session"]
@ -214,6 +235,7 @@ class UIAuthTests(unittest.HomeserverTestCase):
# Make another request providing the UI auth flow. # Make another request providing the UI auth flow.
self.delete_device( self.delete_device(
self.user_tok,
device_id, device_id,
200, 200,
{ {
@ -233,12 +255,13 @@ class UIAuthTests(unittest.HomeserverTestCase):
UIA - check that still works. UIA - check that still works.
""" """
device_id = self.get_device_ids()[0] device_id = self.get_device_ids(self.user_tok)[0]
channel = self.delete_device(device_id, 401) channel = self.delete_device(self.user_tok, device_id, 401)
session = channel.json_body["session"] session = channel.json_body["session"]
# Make another request providing the UI auth flow. # Make another request providing the UI auth flow.
self.delete_device( self.delete_device(
self.user_tok,
device_id, device_id,
200, 200,
{ {
@ -264,7 +287,7 @@ class UIAuthTests(unittest.HomeserverTestCase):
# Create a second login. # Create a second login.
self.login("test", self.user_pass) self.login("test", self.user_pass)
device_ids = self.get_device_ids() device_ids = self.get_device_ids(self.user_tok)
self.assertEqual(len(device_ids), 2) self.assertEqual(len(device_ids), 2)
# Attempt to delete the first device. # Attempt to delete the first device.
@ -298,12 +321,12 @@ class UIAuthTests(unittest.HomeserverTestCase):
# Create a second login. # Create a second login.
self.login("test", self.user_pass) self.login("test", self.user_pass)
device_ids = self.get_device_ids() device_ids = self.get_device_ids(self.user_tok)
self.assertEqual(len(device_ids), 2) self.assertEqual(len(device_ids), 2)
# Attempt to delete the first device. # Attempt to delete the first device.
# Returns a 401 as per the spec # Returns a 401 as per the spec
channel = self.delete_device(device_ids[0], 401) channel = self.delete_device(self.user_tok, device_ids[0], 401)
# Grab the session # Grab the session
session = channel.json_body["session"] session = channel.json_body["session"]
@ -313,6 +336,7 @@ class UIAuthTests(unittest.HomeserverTestCase):
# Make another request providing the UI auth flow, but try to delete the # Make another request providing the UI auth flow, but try to delete the
# second device. This results in an error. # second device. This results in an error.
self.delete_device( self.delete_device(
self.user_tok,
device_ids[1], device_ids[1],
403, 403,
{ {
@ -324,3 +348,39 @@ class UIAuthTests(unittest.HomeserverTestCase):
}, },
}, },
) )
def test_does_not_offer_password_for_sso_user(self):
login_resp = self.helper.login_via_oidc("username")
user_tok = login_resp["access_token"]
device_id = login_resp["device_id"]
# now call the device deletion API: we should get the option to auth with SSO
# and not password.
channel = self.delete_device(user_tok, device_id, 401)
flows = channel.json_body["flows"]
self.assertEqual(flows, [{"stages": ["m.login.sso"]}])
def test_does_not_offer_sso_for_password_user(self):
# now call the device deletion API: we should get the option to auth with SSO
# and not password.
device_ids = self.get_device_ids(self.user_tok)
channel = self.delete_device(self.user_tok, device_ids[0], 401)
flows = channel.json_body["flows"]
self.assertEqual(flows, [{"stages": ["m.login.password"]}])
def test_offers_both_flows_for_upgraded_user(self):
"""A user that had a password and then logged in with SSO should get both flows
"""
login_resp = self.helper.login_via_oidc(UserID.from_string(self.user).localpart)
self.assertEqual(login_resp["user_id"], self.user)
device_ids = self.get_device_ids(self.user_tok)
channel = self.delete_device(self.user_tok, device_ids[0], 401)
flows = channel.json_body["flows"]
# we have no particular expectations of ordering here
self.assertIn({"stages": ["m.login.password"]}, flows)
self.assertIn({"stages": ["m.login.sso"]}, flows)
self.assertEqual(len(flows), 2)

View File

@ -259,6 +259,7 @@ def make_request(
for k, v in custom_headers: for k, v in custom_headers:
req.requestHeaders.addRawHeader(k, v) req.requestHeaders.addRawHeader(k, v)
req.parseCookies()
req.requestReceived(method, path, b"1.1") req.requestReceived(method, path, b"1.1")
if await_result: if await_result: