UIA: offer only available auth flows
During user-interactive auth, do not offer password auth to users with no password, nor SSO auth to users with no SSO. Fixes #7559.pull/8858/head
parent
76469898ee
commit
0bac276890
|
@ -193,9 +193,7 @@ class AuthHandler(BaseHandler):
|
||||||
self.hs = hs # FIXME better possibility to access registrationHandler later?
|
self.hs = hs # FIXME better possibility to access registrationHandler later?
|
||||||
self.macaroon_gen = hs.get_macaroon_generator()
|
self.macaroon_gen = hs.get_macaroon_generator()
|
||||||
self._password_enabled = hs.config.password_enabled
|
self._password_enabled = hs.config.password_enabled
|
||||||
self._sso_enabled = (
|
self._password_localdb_enabled = hs.config.password_localdb_enabled
|
||||||
hs.config.cas_enabled or hs.config.saml2_enabled or hs.config.oidc_enabled
|
|
||||||
)
|
|
||||||
|
|
||||||
# we keep this as a list despite the O(N^2) implication so that we can
|
# we keep this as a list despite the O(N^2) implication so that we can
|
||||||
# keep PASSWORD first and avoid confusing clients which pick the first
|
# keep PASSWORD first and avoid confusing clients which pick the first
|
||||||
|
@ -205,7 +203,7 @@ class AuthHandler(BaseHandler):
|
||||||
|
|
||||||
# start out by assuming PASSWORD is enabled; we will remove it later if not.
|
# start out by assuming PASSWORD is enabled; we will remove it later if not.
|
||||||
login_types = []
|
login_types = []
|
||||||
if hs.config.password_localdb_enabled:
|
if self._password_localdb_enabled:
|
||||||
login_types.append(LoginType.PASSWORD)
|
login_types.append(LoginType.PASSWORD)
|
||||||
|
|
||||||
for provider in self.password_providers:
|
for provider in self.password_providers:
|
||||||
|
@ -219,14 +217,6 @@ class AuthHandler(BaseHandler):
|
||||||
|
|
||||||
self._supported_login_types = login_types
|
self._supported_login_types = login_types
|
||||||
|
|
||||||
# Login types and UI Auth types have a heavy overlap, but are not
|
|
||||||
# necessarily identical. Login types have SSO (and other login types)
|
|
||||||
# added in the rest layer, see synapse.rest.client.v1.login.LoginRestServerlet.on_GET.
|
|
||||||
ui_auth_types = login_types.copy()
|
|
||||||
if self._sso_enabled:
|
|
||||||
ui_auth_types.append(LoginType.SSO)
|
|
||||||
self._supported_ui_auth_types = ui_auth_types
|
|
||||||
|
|
||||||
# Ratelimiter for failed auth during UIA. Uses same ratelimit config
|
# Ratelimiter for failed auth during UIA. Uses same ratelimit config
|
||||||
# as per `rc_login.failed_attempts`.
|
# as per `rc_login.failed_attempts`.
|
||||||
self._failed_uia_attempts_ratelimiter = Ratelimiter(
|
self._failed_uia_attempts_ratelimiter = Ratelimiter(
|
||||||
|
@ -339,7 +329,10 @@ class AuthHandler(BaseHandler):
|
||||||
self._failed_uia_attempts_ratelimiter.ratelimit(user_id, update=False)
|
self._failed_uia_attempts_ratelimiter.ratelimit(user_id, update=False)
|
||||||
|
|
||||||
# build a list of supported flows
|
# build a list of supported flows
|
||||||
flows = [[login_type] for login_type in self._supported_ui_auth_types]
|
supported_ui_auth_types = await self._get_available_ui_auth_types(
|
||||||
|
requester.user
|
||||||
|
)
|
||||||
|
flows = [[login_type] for login_type in supported_ui_auth_types]
|
||||||
|
|
||||||
try:
|
try:
|
||||||
result, params, session_id = await self.check_ui_auth(
|
result, params, session_id = await self.check_ui_auth(
|
||||||
|
@ -351,7 +344,7 @@ class AuthHandler(BaseHandler):
|
||||||
raise
|
raise
|
||||||
|
|
||||||
# find the completed login type
|
# find the completed login type
|
||||||
for login_type in self._supported_ui_auth_types:
|
for login_type in supported_ui_auth_types:
|
||||||
if login_type not in result:
|
if login_type not in result:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
@ -367,6 +360,41 @@ class AuthHandler(BaseHandler):
|
||||||
|
|
||||||
return params, session_id
|
return params, session_id
|
||||||
|
|
||||||
|
async def _get_available_ui_auth_types(self, user: UserID) -> Iterable[str]:
|
||||||
|
"""Get a list of the authentication types this user can use
|
||||||
|
"""
|
||||||
|
|
||||||
|
ui_auth_types = set()
|
||||||
|
|
||||||
|
# if the HS supports password auth, and the user has a non-null password, we
|
||||||
|
# support password auth
|
||||||
|
if self._password_localdb_enabled and self._password_enabled:
|
||||||
|
lookupres = await self._find_user_id_and_pwd_hash(user.to_string())
|
||||||
|
if lookupres:
|
||||||
|
_, password_hash = lookupres
|
||||||
|
if password_hash:
|
||||||
|
ui_auth_types.add(LoginType.PASSWORD)
|
||||||
|
|
||||||
|
# also allow auth from password providers
|
||||||
|
for provider in self.password_providers:
|
||||||
|
for t in provider.get_supported_login_types().keys():
|
||||||
|
if t == LoginType.PASSWORD and not self._password_enabled:
|
||||||
|
continue
|
||||||
|
ui_auth_types.add(t)
|
||||||
|
|
||||||
|
# if sso is enabled, allow the user to log in via SSO iff they have a mapping
|
||||||
|
# from sso to mxid.
|
||||||
|
if self.hs.config.saml2.saml2_enabled or self.hs.config.oidc.oidc_enabled:
|
||||||
|
if await self.store.get_external_ids_by_user(user.to_string()):
|
||||||
|
ui_auth_types.add(LoginType.SSO)
|
||||||
|
|
||||||
|
# Our CAS impl does not (yet) correctly register users in user_external_ids,
|
||||||
|
# so always offer that if it's available.
|
||||||
|
if self.hs.config.cas.cas_enabled:
|
||||||
|
ui_auth_types.add(LoginType.SSO)
|
||||||
|
|
||||||
|
return ui_auth_types
|
||||||
|
|
||||||
def get_enabled_auth_types(self):
|
def get_enabled_auth_types(self):
|
||||||
"""Return the enabled user-interactive authentication types
|
"""Return the enabled user-interactive authentication types
|
||||||
|
|
||||||
|
@ -1029,7 +1057,7 @@ class AuthHandler(BaseHandler):
|
||||||
if result:
|
if result:
|
||||||
return result
|
return result
|
||||||
|
|
||||||
if login_type == LoginType.PASSWORD and self.hs.config.password_localdb_enabled:
|
if login_type == LoginType.PASSWORD and self._password_localdb_enabled:
|
||||||
known_login_type = True
|
known_login_type = True
|
||||||
|
|
||||||
# we've already checked that there is a (valid) password field
|
# we've already checked that there is a (valid) password field
|
||||||
|
|
|
@ -463,6 +463,23 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
|
||||||
desc="get_user_by_external_id",
|
desc="get_user_by_external_id",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
async def get_external_ids_by_user(self, mxid: str) -> List[Tuple[str, str]]:
|
||||||
|
"""Look up external ids for the given user
|
||||||
|
|
||||||
|
Args:
|
||||||
|
mxid: the MXID to be looked up
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuples of (auth_provider, external_id)
|
||||||
|
"""
|
||||||
|
res = await self.db_pool.simple_select_list(
|
||||||
|
table="user_external_ids",
|
||||||
|
keyvalues={"user_id": mxid},
|
||||||
|
retcols=("auth_provider", "external_id"),
|
||||||
|
desc="get_external_ids_by_user",
|
||||||
|
)
|
||||||
|
return [(r["auth_provider"], r["external_id"]) for r in res]
|
||||||
|
|
||||||
async def count_all_users(self):
|
async def count_all_users(self):
|
||||||
"""Counts all users registered on the homeserver."""
|
"""Counts all users registered on the homeserver."""
|
||||||
|
|
||||||
|
@ -963,6 +980,14 @@ class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
|
||||||
"users_set_deactivated_flag", self._background_update_set_deactivated_flag
|
"users_set_deactivated_flag", self._background_update_set_deactivated_flag
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self.db_pool.updates.register_background_index_update(
|
||||||
|
"user_external_ids_user_id_idx",
|
||||||
|
index_name="user_external_ids_user_id_idx",
|
||||||
|
table="user_external_ids",
|
||||||
|
columns=["user_id"],
|
||||||
|
unique=False,
|
||||||
|
)
|
||||||
|
|
||||||
async def _background_update_set_deactivated_flag(self, progress, batch_size):
|
async def _background_update_set_deactivated_flag(self, progress, batch_size):
|
||||||
"""Retrieves a list of all deactivated users and sets the 'deactivated' flag to 1
|
"""Retrieves a list of all deactivated users and sets the 'deactivated' flag to 1
|
||||||
for each of them.
|
for each of them.
|
||||||
|
|
|
@ -0,0 +1,17 @@
|
||||||
|
/* Copyright 2020 The Matrix.org Foundation C.I.C
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
INSERT INTO background_updates (ordering, update_name, progress_json) VALUES
|
||||||
|
(5825, 'user_external_ids_user_id_idx', '{}');
|
|
@ -2,7 +2,7 @@
|
||||||
# Copyright 2014-2016 OpenMarket Ltd
|
# Copyright 2014-2016 OpenMarket Ltd
|
||||||
# Copyright 2017 Vector Creations Ltd
|
# Copyright 2017 Vector Creations Ltd
|
||||||
# Copyright 2018-2019 New Vector Ltd
|
# Copyright 2018-2019 New Vector Ltd
|
||||||
# Copyright 2019 The Matrix.org Foundation C.I.C.
|
# Copyright 2019-2020 The Matrix.org Foundation C.I.C.
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
# you may not use this file except in compliance with the License.
|
# you may not use this file except in compliance with the License.
|
||||||
|
@ -17,17 +17,23 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import json
|
import json
|
||||||
|
import re
|
||||||
import time
|
import time
|
||||||
|
import urllib.parse
|
||||||
from typing import Any, Dict, Optional
|
from typing import Any, Dict, Optional
|
||||||
|
|
||||||
|
from mock import patch
|
||||||
|
|
||||||
import attr
|
import attr
|
||||||
|
|
||||||
from twisted.web.resource import Resource
|
from twisted.web.resource import Resource
|
||||||
from twisted.web.server import Site
|
from twisted.web.server import Site
|
||||||
|
|
||||||
from synapse.api.constants import Membership
|
from synapse.api.constants import Membership
|
||||||
|
from synapse.types import JsonDict
|
||||||
|
|
||||||
from tests.server import FakeSite, make_request
|
from tests.server import FakeSite, make_request
|
||||||
|
from tests.test_utils import FakeResponse
|
||||||
|
|
||||||
|
|
||||||
@attr.s
|
@attr.s
|
||||||
|
@ -344,3 +350,111 @@ class RestHelper:
|
||||||
)
|
)
|
||||||
|
|
||||||
return channel.json_body
|
return channel.json_body
|
||||||
|
|
||||||
|
def login_via_oidc(self, remote_user_id: str) -> JsonDict:
|
||||||
|
"""Log in (as a new user) via OIDC
|
||||||
|
|
||||||
|
Returns the result of the final token login.
|
||||||
|
|
||||||
|
Requires that "oidc_config" in the homeserver config be set appropriately
|
||||||
|
(TEST_OIDC_CONFIG is a suitable example) - and by implication, needs a
|
||||||
|
"public_base_url".
|
||||||
|
|
||||||
|
Also requires the login servlet and the OIDC callback resource to be mounted at
|
||||||
|
the normal places.
|
||||||
|
"""
|
||||||
|
client_redirect_url = "https://x"
|
||||||
|
|
||||||
|
# first hit the redirect url (which will issue a cookie and state)
|
||||||
|
_, channel = make_request(
|
||||||
|
self.hs.get_reactor(),
|
||||||
|
self.site,
|
||||||
|
"GET",
|
||||||
|
"/login/sso/redirect?redirectUrl=" + client_redirect_url,
|
||||||
|
)
|
||||||
|
# that will redirect to the OIDC IdP, but we skip that and go straight
|
||||||
|
# back to synapse's OIDC callback resource. However, we do need the "state"
|
||||||
|
# param that synapse passes to the IdP via query params, and the cookie that
|
||||||
|
# synapse passes to the client.
|
||||||
|
assert channel.code == 302
|
||||||
|
oauth_uri = channel.headers.getRawHeaders("Location")[0]
|
||||||
|
params = urllib.parse.parse_qs(urllib.parse.urlparse(oauth_uri).query)
|
||||||
|
redirect_uri = "%s?%s" % (
|
||||||
|
urllib.parse.urlparse(params["redirect_uri"][0]).path,
|
||||||
|
urllib.parse.urlencode({"state": params["state"][0], "code": "TEST_CODE"}),
|
||||||
|
)
|
||||||
|
cookies = {}
|
||||||
|
for h in channel.headers.getRawHeaders("Set-Cookie"):
|
||||||
|
parts = h.split(";")
|
||||||
|
k, v = parts[0].split("=", maxsplit=1)
|
||||||
|
cookies[k] = v
|
||||||
|
|
||||||
|
# before we hit the callback uri, stub out some methods in the http client so
|
||||||
|
# that we don't have to handle full HTTPS requests.
|
||||||
|
|
||||||
|
# (expected url, json response) pairs, in the order we expect them.
|
||||||
|
expected_requests = [
|
||||||
|
# first we get a hit to the token endpoint, which we tell to return
|
||||||
|
# a dummy OIDC access token
|
||||||
|
("https://issuer.test/token", {"access_token": "TEST"}),
|
||||||
|
# and then one to the user_info endpoint, which returns our remote user id.
|
||||||
|
("https://issuer.test/userinfo", {"sub": remote_user_id}),
|
||||||
|
]
|
||||||
|
|
||||||
|
async def mock_req(method: str, uri: str, data=None, headers=None):
|
||||||
|
(expected_uri, resp_obj) = expected_requests.pop(0)
|
||||||
|
assert uri == expected_uri
|
||||||
|
resp = FakeResponse(
|
||||||
|
code=200, phrase=b"OK", body=json.dumps(resp_obj).encode("utf-8"),
|
||||||
|
)
|
||||||
|
return resp
|
||||||
|
|
||||||
|
with patch.object(self.hs.get_proxied_http_client(), "request", mock_req):
|
||||||
|
# now hit the callback URI with the right params and a made-up code
|
||||||
|
_, channel = make_request(
|
||||||
|
self.hs.get_reactor(),
|
||||||
|
self.site,
|
||||||
|
"GET",
|
||||||
|
redirect_uri,
|
||||||
|
custom_headers=[
|
||||||
|
("Cookie", "%s=%s" % (k, v)) for (k, v) in cookies.items()
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
# expect a confirmation page
|
||||||
|
assert channel.code == 200
|
||||||
|
|
||||||
|
# fish the matrix login token out of the body of the confirmation page
|
||||||
|
m = re.search(
|
||||||
|
'a href="%s.*loginToken=([^"]*)"' % (client_redirect_url,),
|
||||||
|
channel.result["body"].decode("utf-8"),
|
||||||
|
)
|
||||||
|
assert m
|
||||||
|
login_token = m.group(1)
|
||||||
|
|
||||||
|
# finally, submit the matrix login token to the login API, which gives us our
|
||||||
|
# matrix access token and device id.
|
||||||
|
_, channel = make_request(
|
||||||
|
self.hs.get_reactor(),
|
||||||
|
self.site,
|
||||||
|
"POST",
|
||||||
|
"/login",
|
||||||
|
content={"type": "m.login.token", "token": login_token},
|
||||||
|
)
|
||||||
|
assert channel.code == 200
|
||||||
|
return channel.json_body
|
||||||
|
|
||||||
|
|
||||||
|
# an 'oidc_config' suitable for login_with_oidc.
|
||||||
|
TEST_OIDC_CONFIG = {
|
||||||
|
"enabled": True,
|
||||||
|
"discover": False,
|
||||||
|
"issuer": "https://issuer.test",
|
||||||
|
"client_id": "test-client-id",
|
||||||
|
"client_secret": "test-client-secret",
|
||||||
|
"scopes": ["profile"],
|
||||||
|
"authorization_endpoint": "https://z",
|
||||||
|
"token_endpoint": "https://issuer.test/token",
|
||||||
|
"userinfo_endpoint": "https://issuer.test/userinfo",
|
||||||
|
"user_mapping_provider": {"config": {"localpart_template": "{{ user.sub }}"}},
|
||||||
|
}
|
||||||
|
|
|
@ -12,6 +12,7 @@
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
from typing import List, Union
|
from typing import List, Union
|
||||||
|
|
||||||
from twisted.internet.defer import succeed
|
from twisted.internet.defer import succeed
|
||||||
|
@ -22,9 +23,11 @@ from synapse.handlers.ui_auth.checkers import UserInteractiveAuthChecker
|
||||||
from synapse.http.site import SynapseRequest
|
from synapse.http.site import SynapseRequest
|
||||||
from synapse.rest.client.v1 import login
|
from synapse.rest.client.v1 import login
|
||||||
from synapse.rest.client.v2_alpha import auth, devices, register
|
from synapse.rest.client.v2_alpha import auth, devices, register
|
||||||
from synapse.types import JsonDict
|
from synapse.rest.oidc import OIDCResource
|
||||||
|
from synapse.types import JsonDict, UserID
|
||||||
|
|
||||||
from tests import unittest
|
from tests import unittest
|
||||||
|
from tests.rest.client.v1.utils import TEST_OIDC_CONFIG
|
||||||
from tests.server import FakeChannel
|
from tests.server import FakeChannel
|
||||||
|
|
||||||
|
|
||||||
|
@ -156,27 +159,45 @@ class UIAuthTests(unittest.HomeserverTestCase):
|
||||||
register.register_servlets,
|
register.register_servlets,
|
||||||
]
|
]
|
||||||
|
|
||||||
|
def default_config(self):
|
||||||
|
config = super().default_config()
|
||||||
|
|
||||||
|
# we enable OIDC as a way of testing SSO flows
|
||||||
|
oidc_config = {}
|
||||||
|
oidc_config.update(TEST_OIDC_CONFIG)
|
||||||
|
oidc_config["allow_existing_users"] = True
|
||||||
|
|
||||||
|
config["oidc_config"] = oidc_config
|
||||||
|
config["public_baseurl"] = "https://synapse.test"
|
||||||
|
return config
|
||||||
|
|
||||||
|
def create_resource_dict(self):
|
||||||
|
resource_dict = super().create_resource_dict()
|
||||||
|
# mount the OIDC resource at /_synapse/oidc
|
||||||
|
resource_dict["/_synapse/oidc"] = OIDCResource(self.hs)
|
||||||
|
return resource_dict
|
||||||
|
|
||||||
def prepare(self, reactor, clock, hs):
|
def prepare(self, reactor, clock, hs):
|
||||||
self.user_pass = "pass"
|
self.user_pass = "pass"
|
||||||
self.user = self.register_user("test", self.user_pass)
|
self.user = self.register_user("test", self.user_pass)
|
||||||
self.user_tok = self.login("test", self.user_pass)
|
self.user_tok = self.login("test", self.user_pass)
|
||||||
|
|
||||||
def get_device_ids(self) -> List[str]:
|
def get_device_ids(self, access_token: str) -> List[str]:
|
||||||
# Get the list of devices so one can be deleted.
|
# Get the list of devices so one can be deleted.
|
||||||
request, channel = self.make_request(
|
_, channel = self.make_request("GET", "devices", access_token=access_token,)
|
||||||
"GET", "devices", access_token=self.user_tok,
|
self.assertEqual(channel.code, 200)
|
||||||
) # type: SynapseRequest, FakeChannel
|
|
||||||
|
|
||||||
# Get the ID of the device.
|
|
||||||
self.assertEqual(request.code, 200)
|
|
||||||
return [d["device_id"] for d in channel.json_body["devices"]]
|
return [d["device_id"] for d in channel.json_body["devices"]]
|
||||||
|
|
||||||
def delete_device(
|
def delete_device(
|
||||||
self, device: str, expected_response: int, body: Union[bytes, JsonDict] = b""
|
self,
|
||||||
|
access_token: str,
|
||||||
|
device: str,
|
||||||
|
expected_response: int,
|
||||||
|
body: Union[bytes, JsonDict] = b"",
|
||||||
) -> FakeChannel:
|
) -> FakeChannel:
|
||||||
"""Delete an individual device."""
|
"""Delete an individual device."""
|
||||||
request, channel = self.make_request(
|
request, channel = self.make_request(
|
||||||
"DELETE", "devices/" + device, body, access_token=self.user_tok
|
"DELETE", "devices/" + device, body, access_token=access_token,
|
||||||
) # type: SynapseRequest, FakeChannel
|
) # type: SynapseRequest, FakeChannel
|
||||||
|
|
||||||
# Ensure the response is sane.
|
# Ensure the response is sane.
|
||||||
|
@ -201,11 +222,11 @@ class UIAuthTests(unittest.HomeserverTestCase):
|
||||||
"""
|
"""
|
||||||
Test user interactive authentication outside of registration.
|
Test user interactive authentication outside of registration.
|
||||||
"""
|
"""
|
||||||
device_id = self.get_device_ids()[0]
|
device_id = self.get_device_ids(self.user_tok)[0]
|
||||||
|
|
||||||
# Attempt to delete this device.
|
# Attempt to delete this device.
|
||||||
# Returns a 401 as per the spec
|
# Returns a 401 as per the spec
|
||||||
channel = self.delete_device(device_id, 401)
|
channel = self.delete_device(self.user_tok, device_id, 401)
|
||||||
|
|
||||||
# Grab the session
|
# Grab the session
|
||||||
session = channel.json_body["session"]
|
session = channel.json_body["session"]
|
||||||
|
@ -214,6 +235,7 @@ class UIAuthTests(unittest.HomeserverTestCase):
|
||||||
|
|
||||||
# Make another request providing the UI auth flow.
|
# Make another request providing the UI auth flow.
|
||||||
self.delete_device(
|
self.delete_device(
|
||||||
|
self.user_tok,
|
||||||
device_id,
|
device_id,
|
||||||
200,
|
200,
|
||||||
{
|
{
|
||||||
|
@ -233,12 +255,13 @@ class UIAuthTests(unittest.HomeserverTestCase):
|
||||||
UIA - check that still works.
|
UIA - check that still works.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
device_id = self.get_device_ids()[0]
|
device_id = self.get_device_ids(self.user_tok)[0]
|
||||||
channel = self.delete_device(device_id, 401)
|
channel = self.delete_device(self.user_tok, device_id, 401)
|
||||||
session = channel.json_body["session"]
|
session = channel.json_body["session"]
|
||||||
|
|
||||||
# Make another request providing the UI auth flow.
|
# Make another request providing the UI auth flow.
|
||||||
self.delete_device(
|
self.delete_device(
|
||||||
|
self.user_tok,
|
||||||
device_id,
|
device_id,
|
||||||
200,
|
200,
|
||||||
{
|
{
|
||||||
|
@ -264,7 +287,7 @@ class UIAuthTests(unittest.HomeserverTestCase):
|
||||||
# Create a second login.
|
# Create a second login.
|
||||||
self.login("test", self.user_pass)
|
self.login("test", self.user_pass)
|
||||||
|
|
||||||
device_ids = self.get_device_ids()
|
device_ids = self.get_device_ids(self.user_tok)
|
||||||
self.assertEqual(len(device_ids), 2)
|
self.assertEqual(len(device_ids), 2)
|
||||||
|
|
||||||
# Attempt to delete the first device.
|
# Attempt to delete the first device.
|
||||||
|
@ -298,12 +321,12 @@ class UIAuthTests(unittest.HomeserverTestCase):
|
||||||
# Create a second login.
|
# Create a second login.
|
||||||
self.login("test", self.user_pass)
|
self.login("test", self.user_pass)
|
||||||
|
|
||||||
device_ids = self.get_device_ids()
|
device_ids = self.get_device_ids(self.user_tok)
|
||||||
self.assertEqual(len(device_ids), 2)
|
self.assertEqual(len(device_ids), 2)
|
||||||
|
|
||||||
# Attempt to delete the first device.
|
# Attempt to delete the first device.
|
||||||
# Returns a 401 as per the spec
|
# Returns a 401 as per the spec
|
||||||
channel = self.delete_device(device_ids[0], 401)
|
channel = self.delete_device(self.user_tok, device_ids[0], 401)
|
||||||
|
|
||||||
# Grab the session
|
# Grab the session
|
||||||
session = channel.json_body["session"]
|
session = channel.json_body["session"]
|
||||||
|
@ -313,6 +336,7 @@ class UIAuthTests(unittest.HomeserverTestCase):
|
||||||
# Make another request providing the UI auth flow, but try to delete the
|
# Make another request providing the UI auth flow, but try to delete the
|
||||||
# second device. This results in an error.
|
# second device. This results in an error.
|
||||||
self.delete_device(
|
self.delete_device(
|
||||||
|
self.user_tok,
|
||||||
device_ids[1],
|
device_ids[1],
|
||||||
403,
|
403,
|
||||||
{
|
{
|
||||||
|
@ -324,3 +348,39 @@ class UIAuthTests(unittest.HomeserverTestCase):
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def test_does_not_offer_password_for_sso_user(self):
|
||||||
|
login_resp = self.helper.login_via_oidc("username")
|
||||||
|
user_tok = login_resp["access_token"]
|
||||||
|
device_id = login_resp["device_id"]
|
||||||
|
|
||||||
|
# now call the device deletion API: we should get the option to auth with SSO
|
||||||
|
# and not password.
|
||||||
|
channel = self.delete_device(user_tok, device_id, 401)
|
||||||
|
|
||||||
|
flows = channel.json_body["flows"]
|
||||||
|
self.assertEqual(flows, [{"stages": ["m.login.sso"]}])
|
||||||
|
|
||||||
|
def test_does_not_offer_sso_for_password_user(self):
|
||||||
|
# now call the device deletion API: we should get the option to auth with SSO
|
||||||
|
# and not password.
|
||||||
|
device_ids = self.get_device_ids(self.user_tok)
|
||||||
|
channel = self.delete_device(self.user_tok, device_ids[0], 401)
|
||||||
|
|
||||||
|
flows = channel.json_body["flows"]
|
||||||
|
self.assertEqual(flows, [{"stages": ["m.login.password"]}])
|
||||||
|
|
||||||
|
def test_offers_both_flows_for_upgraded_user(self):
|
||||||
|
"""A user that had a password and then logged in with SSO should get both flows
|
||||||
|
"""
|
||||||
|
login_resp = self.helper.login_via_oidc(UserID.from_string(self.user).localpart)
|
||||||
|
self.assertEqual(login_resp["user_id"], self.user)
|
||||||
|
|
||||||
|
device_ids = self.get_device_ids(self.user_tok)
|
||||||
|
channel = self.delete_device(self.user_tok, device_ids[0], 401)
|
||||||
|
|
||||||
|
flows = channel.json_body["flows"]
|
||||||
|
# we have no particular expectations of ordering here
|
||||||
|
self.assertIn({"stages": ["m.login.password"]}, flows)
|
||||||
|
self.assertIn({"stages": ["m.login.sso"]}, flows)
|
||||||
|
self.assertEqual(len(flows), 2)
|
||||||
|
|
|
@ -259,6 +259,7 @@ def make_request(
|
||||||
for k, v in custom_headers:
|
for k, v in custom_headers:
|
||||||
req.requestHeaders.addRawHeader(k, v)
|
req.requestHeaders.addRawHeader(k, v)
|
||||||
|
|
||||||
|
req.parseCookies()
|
||||||
req.requestReceived(method, path, b"1.1")
|
req.requestReceived(method, path, b"1.1")
|
||||||
|
|
||||||
if await_result:
|
if await_result:
|
||||||
|
|
Loading…
Reference in New Issue