diff --git a/src/components/structures/MatrixChat.tsx b/src/components/structures/MatrixChat.tsx index 7ce7ae2c17..a91d6fb928 100644 --- a/src/components/structures/MatrixChat.tsx +++ b/src/components/structures/MatrixChat.tsx @@ -393,7 +393,8 @@ export default class MatrixChat extends React.PureComponent { return; } - if (firstScreen === "login" || firstScreen === "register" || firstScreen === "forgot_password") { + // If the first screen is an auth screen, we don't want to wait for login. + if (firstScreen !== null && AUTH_SCREENS.includes(firstScreen)) { this.showScreenAfterLogin(); } } diff --git a/test/components/structures/MatrixChat-test.tsx b/test/components/structures/MatrixChat-test.tsx index c01b0025ab..ac8ad257f7 100644 --- a/test/components/structures/MatrixChat-test.tsx +++ b/test/components/structures/MatrixChat-test.tsx @@ -41,6 +41,7 @@ import { MockClientWithEventEmitter, mockPlatformPeg, resetJsDomAfterEach, + unmockClientPeg, } from "../../test-utils"; import * as leaveRoomUtils from "../../../src/utils/leave-behaviour"; import * as voiceBroadcastUtils from "../../../src/voice-broadcast/utils/cleanUpBroadcasts"; @@ -51,6 +52,7 @@ import { PosthogAnalytics } from "../../../src/PosthogAnalytics"; import PlatformPeg from "../../../src/PlatformPeg"; import EventIndexPeg from "../../../src/indexing/EventIndexPeg"; import * as Lifecycle from "../../../src/Lifecycle"; +import { SSO_HOMESERVER_URL_KEY, SSO_ID_SERVER_URL_KEY } from "../../../src/BasePlatform"; jest.mock("matrix-js-sdk/src/oidc/authorize", () => ({ completeAuthorizationCodeGrant: jest.fn(), @@ -69,6 +71,7 @@ describe("", () => { setCanResetTimelineCallback: jest.fn(), isInitialSyncComplete: jest.fn(), getSyncState: jest.fn(), + getSsoLoginUrl: jest.fn(), getSyncStateData: jest.fn().mockReturnValue(null), getThirdpartyProtocols: jest.fn().mockResolvedValue({}), getClientWellKnown: jest.fn().mockReturnValue({}), @@ -1107,6 +1110,64 @@ describe("", () => { }); }); + describe("automatic SSO selection", () => { + let ssoClient: ReturnType; + let hrefSetter: jest.Mock; + beforeEach(() => { + ssoClient = getMockClientWithEventEmitter({ + ...getMockClientMethods(), + getHomeserverUrl: jest.fn().mockReturnValue("matrix.example.com"), + getIdentityServerUrl: jest.fn().mockReturnValue("ident.example.com"), + getSsoLoginUrl: jest.fn().mockReturnValue("http://my-sso-url"), + }); + // this is used to create a temporary client to cleanup after logout + jest.spyOn(MatrixJs, "createClient").mockClear().mockReturnValue(ssoClient); + mockPlatformPeg(); + // Ensure we don't have a client peg as we aren't logged in. + unmockClientPeg(); + + hrefSetter = jest.fn(); + const originalHref = window.location.href.toString(); + Object.defineProperty(window, "location", { + value: { + get href() { + return originalHref; + }, + set href(href) { + hrefSetter(href); + }, + }, + writable: true, + }); + }); + + it("should automatically setup and redirect to SSO login", async () => { + getComponent({ + initialScreenAfterLogin: { + screen: "start_sso", + }, + }); + await flushPromises(); + expect(ssoClient.getSsoLoginUrl).toHaveBeenCalledWith("http://localhost/", "sso", undefined, undefined); + expect(window.localStorage.getItem(SSO_HOMESERVER_URL_KEY)).toEqual("matrix.example.com"); + expect(window.localStorage.getItem(SSO_ID_SERVER_URL_KEY)).toEqual("ident.example.com"); + expect(hrefSetter).toHaveBeenCalledWith("http://my-sso-url"); + }); + + it("should automatically setup and redirect to CAS login", async () => { + getComponent({ + initialScreenAfterLogin: { + screen: "start_cas", + }, + }); + await flushPromises(); + expect(ssoClient.getSsoLoginUrl).toHaveBeenCalledWith("http://localhost/", "cas", undefined, undefined); + expect(window.localStorage.getItem(SSO_HOMESERVER_URL_KEY)).toEqual("matrix.example.com"); + expect(window.localStorage.getItem(SSO_ID_SERVER_URL_KEY)).toEqual("ident.example.com"); + expect(hrefSetter).toHaveBeenCalledWith("http://my-sso-url"); + }); + }); + describe("Multi-tab lockout", () => { afterEach(() => { Lifecycle.setSessionLockNotStolen();