diff --git a/src/components/views/elements/ImageView.tsx b/src/components/views/elements/ImageView.tsx index 8032f07b6e..711a221994 100644 --- a/src/components/views/elements/ImageView.tsx +++ b/src/components/views/elements/ImageView.tsx @@ -8,9 +8,9 @@ SPDX-License-Identifier: AGPL-3.0-only OR GPL-3.0-only Please see LICENSE files in the repository root for full details. */ -import React, { createRef, CSSProperties } from "react"; +import React, { createRef, CSSProperties, useRef, useState } from "react"; import FocusLock from "react-focus-lock"; -import { MatrixEvent } from "matrix-js-sdk/src/matrix"; +import { MatrixEvent, parseErrorResponse } from "matrix-js-sdk/src/matrix"; import { _t } from "../../../languageHandler"; import MemberAvatar from "../avatars/MemberAvatar"; @@ -30,6 +30,9 @@ import { KeyBindingAction } from "../../../accessibility/KeyboardShortcuts"; import { getKeyBindingsManager } from "../../../KeyBindingsManager"; import { presentableTextForFile } from "../../../utils/FileUtils"; import AccessibleButton from "./AccessibleButton"; +import Modal from "../../../Modal"; +import ErrorDialog from "../dialogs/ErrorDialog"; +import { FileDownloader } from "../../../utils/FileDownloader"; // Max scale to keep gaps around the image const MAX_SCALE = 0.95; @@ -309,15 +312,6 @@ export default class ImageView extends React.Component { this.setZoomAndRotation(cur + 90); }; - private onDownloadClick = (): void => { - const a = document.createElement("a"); - a.href = this.props.src; - if (this.props.name) a.download = this.props.name; - a.target = "_blank"; - a.rel = "noreferrer noopener"; - a.click(); - }; - private onOpenContextMenu = (): void => { this.setState({ contextMenuDisplayed: true, @@ -555,11 +549,7 @@ export default class ImageView extends React.Component { title={_t("lightbox|rotate_right")} onClick={this.onRotateClockwiseClick} /> - + {contextMenuButton} { ); } } + +function DownloadButton({ url, fileName }: { url: string; fileName?: string }): JSX.Element { + const downloader = useRef(new FileDownloader()).current; + const [loading, setLoading] = useState(false); + const blobRef = useRef(); + + function showError(e: unknown): void { + Modal.createDialog(ErrorDialog, { + title: _t("timeline|download_failed"), + description: ( + <> +
{_t("timeline|download_failed_description")}
+
{e instanceof Error ? e.toString() : ""}
+ + ), + }); + setLoading(false); + } + + const onDownloadClick = async (): Promise => { + try { + if (loading) return; + setLoading(true); + + if (blobRef.current) { + // Cheat and trigger a download, again. + return downloadBlob(blobRef.current); + } + + const res = await fetch(url); + if (!res.ok) { + throw parseErrorResponse(res, await res.text()); + } + const blob = await res.blob(); + blobRef.current = blob; + await downloadBlob(blob); + } catch (e) { + showError(e); + } + }; + + async function downloadBlob(blob: Blob): Promise { + await downloader.download({ + blob, + name: fileName ?? _t("common|image"), + }); + setLoading(false); + } + + return ( + + ); +} diff --git a/test/unit-tests/components/views/elements/ImageView-test.tsx b/test/unit-tests/components/views/elements/ImageView-test.tsx index 48a312ed3a..4a23d847cb 100644 --- a/test/unit-tests/components/views/elements/ImageView-test.tsx +++ b/test/unit-tests/components/views/elements/ImageView-test.tsx @@ -7,13 +7,57 @@ */ import React from "react"; -import { render } from "jest-matrix-react"; +import { mocked } from "jest-mock"; +import { render, fireEvent, waitFor } from "jest-matrix-react"; +import fetchMock from "fetch-mock-jest"; import ImageView from "../../../../../src/components/views/elements/ImageView"; +import { FileDownloader } from "../../../../../src/utils/FileDownloader"; +import Modal from "../../../../../src/Modal"; +import ErrorDialog from "../../../../../src/components/views/dialogs/ErrorDialog"; + +jest.mock("../../../../../src/utils/FileDownloader"); describe("", () => { + beforeEach(() => { + jest.resetAllMocks(); + fetchMock.reset(); + }); + it("renders correctly", () => { const { container } = render(); expect(container).toMatchSnapshot(); }); + + it("should download on click", async () => { + fetchMock.get("https://example.com/image.png", "TESTFILE"); + const { getByRole } = render( + , + ); + fireEvent.click(getByRole("button", { name: "Download" })); + await waitFor(() => + expect(mocked(FileDownloader).mock.instances[0].download).toHaveBeenCalledWith({ + blob: expect.anything(), + name: "filename.png", + }), + ); + expect(fetchMock).toHaveFetched("https://example.com/image.png"); + }); + + it("should handle download errors", async () => { + const modalSpy = jest.spyOn(Modal, "createDialog"); + fetchMock.get("https://example.com/image.png", { status: 500 }); + const { getByRole } = render( + , + ); + fireEvent.click(getByRole("button", { name: "Download" })); + await waitFor(() => + expect(modalSpy).toHaveBeenCalledWith( + ErrorDialog, + expect.objectContaining({ + title: "Download failed", + }), + ), + ); + }); });