diff --git a/src/vector/app.tsx b/src/vector/app.tsx index 05f1cd5b85..8fe4a7eb1a 100644 --- a/src/vector/app.tsx +++ b/src/vector/app.tsx @@ -38,7 +38,7 @@ import { QueryDict, encodeParams } from "matrix-js-sdk/src/utils"; import { parseQs } from "./url_utils"; import VectorBasePlatform from "./platform/VectorBasePlatform"; -import { getScreenFromLocation, init as initRouting, onNewScreen } from "./routing"; +import { getInitialScreenAfterLogin, getScreenFromLocation, init as initRouting, onNewScreen } from "./routing"; // add React and ReactPerf to the global namespace, to make them easier to access via the console // this incidentally means we can forget our React imports in JSX files without penalty. @@ -133,6 +133,8 @@ export async function loadApp(fragParams: {}): Promise { const defaultDeviceName = snakedConfig.get("default_device_display_name") ?? platform?.getDefaultDeviceDisplayName(); + const initialScreenAfterLogin = getInitialScreenAfterLogin(window.location); + return ( { startingFragmentQueryParams={fragParams} enableGuest={!config.disable_guests} onTokenLoginCompleted={onTokenLoginCompleted} - initialScreenAfterLogin={getScreenFromLocation(window.location)} + initialScreenAfterLogin={initialScreenAfterLogin} defaultDeviceDisplayName={defaultDeviceName} /> ); diff --git a/src/vector/routing.ts b/src/vector/routing.ts index 04d455f51f..2420ee2fe7 100644 --- a/src/vector/routing.ts +++ b/src/vector/routing.ts @@ -76,3 +76,36 @@ export function onNewScreen(screen: string, replaceLast = false): void { export function init(): void { window.addEventListener("hashchange", onHashChange); } + +const ScreenAfterLoginStorageKey = "mx_screen_after_login"; +function getStoredInitialScreenAfterLogin(): ReturnType | undefined { + const screenAfterLogin = sessionStorage.getItem(ScreenAfterLoginStorageKey); + + return screenAfterLogin ? JSON.parse(screenAfterLogin) : undefined; +} + +function setInitialScreenAfterLogin(screenAfterLogin?: ReturnType): void { + if (screenAfterLogin?.screen) { + sessionStorage.setItem(ScreenAfterLoginStorageKey, JSON.stringify(screenAfterLogin)); + } +} + +/** + * Get the initial screen to be displayed after login, + * for example when trying to view a room via a link before logging in + * + * If the current URL has a screen set that in session storage + * Then retrieve the screen from session storage and return it + * Using session storage allows us to remember login fragments from when returning from OIDC login + * @returns screen and params or undefined + */ +export function getInitialScreenAfterLogin(location: Location): ReturnType | undefined { + const screenAfterLogin = getScreenFromLocation(location); + + if (screenAfterLogin.screen || screenAfterLogin.params) { + setInitialScreenAfterLogin(screenAfterLogin); + } + + const storedScreenAfterLogin = getStoredInitialScreenAfterLogin(); + return storedScreenAfterLogin; +} diff --git a/test/unit-tests/vector/routing-test.ts b/test/unit-tests/vector/routing-test.ts index 3b8df5302d..bb7adccacb 100644 --- a/test/unit-tests/vector/routing-test.ts +++ b/test/unit-tests/vector/routing-test.ts @@ -14,7 +14,7 @@ See the License for the specific language governing permissions and limitations under the License. */ -import { onNewScreen } from "../../../src/vector/routing"; +import { getInitialScreenAfterLogin, onNewScreen } from "../../../src/vector/routing"; describe("onNewScreen", () => { it("should replace history if stripping via fields", () => { @@ -45,3 +45,61 @@ describe("onNewScreen", () => { expect(window.location.replace).not.toHaveBeenCalled(); }); }); + +describe("getInitialScreenAfterLogin", () => { + beforeEach(() => { + jest.spyOn(sessionStorage.__proto__, "getItem").mockClear().mockReturnValue(null); + jest.spyOn(sessionStorage.__proto__, "setItem").mockClear(); + }); + + const makeMockLocation = (hash = "") => { + const url = new URL("https://test.org"); + url.hash = hash; + return url as unknown as Location; + }; + + describe("when current url has no hash", () => { + it("does not set an initial screen in session storage", () => { + getInitialScreenAfterLogin(makeMockLocation()); + expect(sessionStorage.setItem).not.toHaveBeenCalled(); + }); + + it("returns undefined when there is no initial screen in session storage", () => { + expect(getInitialScreenAfterLogin(makeMockLocation())).toBeUndefined(); + }); + + it("returns initial screen from session storage", () => { + const screen = { + screen: "/room/!test", + }; + jest.spyOn(sessionStorage.__proto__, "getItem").mockReturnValue(JSON.stringify(screen)); + expect(getInitialScreenAfterLogin(makeMockLocation())).toEqual(screen); + }); + }); + + describe("when current url has a hash", () => { + it("sets an initial screen in session storage", () => { + const hash = "/room/!test"; + getInitialScreenAfterLogin(makeMockLocation(hash)); + expect(sessionStorage.setItem).toHaveBeenCalledWith( + "mx_screen_after_login", + JSON.stringify({ + screen: "room/!test", + params: {}, + }), + ); + }); + + it("sets an initial screen in session storage with params", () => { + const hash = "/room/!test?param=test"; + getInitialScreenAfterLogin(makeMockLocation(hash)); + expect(sessionStorage.setItem).toHaveBeenCalledWith( + "mx_screen_after_login", + JSON.stringify({ + screen: "room/!test", + params: { param: "test" }, + }), + ); + }); + }); +});