From 19f7864b8d0bdf6003a9ec685a0af13008fd30b2 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 29 Oct 2020 14:03:18 +0000 Subject: [PATCH] Fixup --- synapse/rest/media/v1/media_storage.py | 20 ++++++++++------ tests/replication/test_multi_media_repo.py | 27 ++++++++++++++++++++++ 2 files changed, 40 insertions(+), 7 deletions(-) diff --git a/synapse/rest/media/v1/media_storage.py b/synapse/rest/media/v1/media_storage.py index d3cc994dae..268e0c8f50 100644 --- a/synapse/rest/media/v1/media_storage.py +++ b/synapse/rest/media/v1/media_storage.py @@ -71,7 +71,7 @@ class MediaStorage: with self.store_into_file(file_info) as (f, fname, finish_cb): # Write to the main repository - self.write_to_file(source, f) + await self.write_to_file(source, f) await finish_cb() return fname @@ -116,14 +116,20 @@ class MediaStorage: finished_called = [False] - async def finish(): - for provider in self.storage_providers: - await provider.store_file(path, file_info) - - finished_called[0] = True - try: with open(fname, "wb") as f: + + async def finish(): + # Ensure that all writes have been flushed and close the + # file. + f.flush() + f.close() + + for provider in self.storage_providers: + await provider.store_file(path, file_info) + + finished_called[0] = True + yield f, fname, finish except Exception: try: diff --git a/tests/replication/test_multi_media_repo.py b/tests/replication/test_multi_media_repo.py index 91a10c84cd..cd6329f7a3 100644 --- a/tests/replication/test_multi_media_repo.py +++ b/tests/replication/test_multi_media_repo.py @@ -15,6 +15,7 @@ import logging from binascii import unhexlify from typing import Tuple +import os from twisted.internet.protocol import Factory from twisted.protocols.tls import TLSMemoryBIOFactory @@ -144,6 +145,8 @@ class MediaRepoShardTestCase(BaseMultiWorkerStreamTestCase): hs1 = self.make_worker_hs("synapse.app.generic_worker") hs2 = self.make_worker_hs("synapse.app.generic_worker") + start_count = self._count_remote_media() + # Make two requests without responding to the outbound media requests. channel1, request1 = self._get_media_req(hs1, "example.com:443", "ABC123") channel2, request2 = self._get_media_req(hs2, "example.com:443", "ABC123") @@ -171,6 +174,9 @@ class MediaRepoShardTestCase(BaseMultiWorkerStreamTestCase): self.assertEqual(channel2.code, 200, channel2.result["body"]) self.assertEqual(channel2.result["body"], b"Hello!") + # We expect only one new file to have been persisted. + self.assertEqual(start_count + 1, self._count_remote_media()) + def test_download_image_race(self): """Test that fetching remote *images* from two different processes at the same time works. @@ -180,6 +186,8 @@ class MediaRepoShardTestCase(BaseMultiWorkerStreamTestCase): hs1 = self.make_worker_hs("synapse.app.generic_worker") hs2 = self.make_worker_hs("synapse.app.generic_worker") + start_count = self._count_remote_thumbnails() + channel1, request1 = self._get_media_req(hs1, "example.com:443", "PIC1") channel2, request2 = self._get_media_req(hs2, "example.com:443", "PIC1") @@ -209,6 +217,25 @@ class MediaRepoShardTestCase(BaseMultiWorkerStreamTestCase): self.assertEqual(channel2.code, 200, channel2.result["body"]) self.assertEqual(channel2.result["body"], png_data) + # We expect only three new thumbnails to have been persisted. + self.assertEqual(start_count + 3, self._count_remote_thumbnails()) + + def _count_remote_media(self) -> int: + """Count the number of files in our remote media directory. + """ + path = os.path.join( + self.hs.get_media_repository().primary_base_path, "remote_content" + ) + return sum(len(files) for _, _, files in os.walk(path)) + + def _count_remote_thumbnails(self) -> int: + """Count the number of files in our remote thumbnails directory. + """ + path = os.path.join( + self.hs.get_media_repository().primary_base_path, "remote_thumbnail" + ) + return sum(len(files) for _, _, files in os.walk(path)) + def get_connection_factory(): # this needs to happen once, but not until we are ready to run the first test