pull/8682/head
Erik Johnston 2020-10-29 14:03:18 +00:00
parent 113f6a2767
commit 19f7864b8d
2 changed files with 40 additions and 7 deletions

View File

@ -71,7 +71,7 @@ class MediaStorage:
with self.store_into_file(file_info) as (f, fname, finish_cb):
# Write to the main repository
self.write_to_file(source, f)
await self.write_to_file(source, f)
await finish_cb()
return fname
@ -116,14 +116,20 @@ class MediaStorage:
finished_called = [False]
async def finish():
for provider in self.storage_providers:
await provider.store_file(path, file_info)
finished_called[0] = True
try:
with open(fname, "wb") as f:
async def finish():
# Ensure that all writes have been flushed and close the
# file.
f.flush()
f.close()
for provider in self.storage_providers:
await provider.store_file(path, file_info)
finished_called[0] = True
yield f, fname, finish
except Exception:
try:

View File

@ -15,6 +15,7 @@
import logging
from binascii import unhexlify
from typing import Tuple
import os
from twisted.internet.protocol import Factory
from twisted.protocols.tls import TLSMemoryBIOFactory
@ -144,6 +145,8 @@ class MediaRepoShardTestCase(BaseMultiWorkerStreamTestCase):
hs1 = self.make_worker_hs("synapse.app.generic_worker")
hs2 = self.make_worker_hs("synapse.app.generic_worker")
start_count = self._count_remote_media()
# Make two requests without responding to the outbound media requests.
channel1, request1 = self._get_media_req(hs1, "example.com:443", "ABC123")
channel2, request2 = self._get_media_req(hs2, "example.com:443", "ABC123")
@ -171,6 +174,9 @@ class MediaRepoShardTestCase(BaseMultiWorkerStreamTestCase):
self.assertEqual(channel2.code, 200, channel2.result["body"])
self.assertEqual(channel2.result["body"], b"Hello!")
# We expect only one new file to have been persisted.
self.assertEqual(start_count + 1, self._count_remote_media())
def test_download_image_race(self):
"""Test that fetching remote *images* from two different processes at
the same time works.
@ -180,6 +186,8 @@ class MediaRepoShardTestCase(BaseMultiWorkerStreamTestCase):
hs1 = self.make_worker_hs("synapse.app.generic_worker")
hs2 = self.make_worker_hs("synapse.app.generic_worker")
start_count = self._count_remote_thumbnails()
channel1, request1 = self._get_media_req(hs1, "example.com:443", "PIC1")
channel2, request2 = self._get_media_req(hs2, "example.com:443", "PIC1")
@ -209,6 +217,25 @@ class MediaRepoShardTestCase(BaseMultiWorkerStreamTestCase):
self.assertEqual(channel2.code, 200, channel2.result["body"])
self.assertEqual(channel2.result["body"], png_data)
# We expect only three new thumbnails to have been persisted.
self.assertEqual(start_count + 3, self._count_remote_thumbnails())
def _count_remote_media(self) -> int:
"""Count the number of files in our remote media directory.
"""
path = os.path.join(
self.hs.get_media_repository().primary_base_path, "remote_content"
)
return sum(len(files) for _, _, files in os.walk(path))
def _count_remote_thumbnails(self) -> int:
"""Count the number of files in our remote thumbnails directory.
"""
path = os.path.join(
self.hs.get_media_repository().primary_base_path, "remote_thumbnail"
)
return sum(len(files) for _, _, files in os.walk(path))
def get_connection_factory():
# this needs to happen once, but not until we are ready to run the first test