Fixup
parent
113f6a2767
commit
19f7864b8d
|
@ -71,7 +71,7 @@ class MediaStorage:
|
|||
|
||||
with self.store_into_file(file_info) as (f, fname, finish_cb):
|
||||
# Write to the main repository
|
||||
self.write_to_file(source, f)
|
||||
await self.write_to_file(source, f)
|
||||
await finish_cb()
|
||||
|
||||
return fname
|
||||
|
@ -116,14 +116,20 @@ class MediaStorage:
|
|||
|
||||
finished_called = [False]
|
||||
|
||||
async def finish():
|
||||
for provider in self.storage_providers:
|
||||
await provider.store_file(path, file_info)
|
||||
|
||||
finished_called[0] = True
|
||||
|
||||
try:
|
||||
with open(fname, "wb") as f:
|
||||
|
||||
async def finish():
|
||||
# Ensure that all writes have been flushed and close the
|
||||
# file.
|
||||
f.flush()
|
||||
f.close()
|
||||
|
||||
for provider in self.storage_providers:
|
||||
await provider.store_file(path, file_info)
|
||||
|
||||
finished_called[0] = True
|
||||
|
||||
yield f, fname, finish
|
||||
except Exception:
|
||||
try:
|
||||
|
|
|
@ -15,6 +15,7 @@
|
|||
import logging
|
||||
from binascii import unhexlify
|
||||
from typing import Tuple
|
||||
import os
|
||||
|
||||
from twisted.internet.protocol import Factory
|
||||
from twisted.protocols.tls import TLSMemoryBIOFactory
|
||||
|
@ -144,6 +145,8 @@ class MediaRepoShardTestCase(BaseMultiWorkerStreamTestCase):
|
|||
hs1 = self.make_worker_hs("synapse.app.generic_worker")
|
||||
hs2 = self.make_worker_hs("synapse.app.generic_worker")
|
||||
|
||||
start_count = self._count_remote_media()
|
||||
|
||||
# Make two requests without responding to the outbound media requests.
|
||||
channel1, request1 = self._get_media_req(hs1, "example.com:443", "ABC123")
|
||||
channel2, request2 = self._get_media_req(hs2, "example.com:443", "ABC123")
|
||||
|
@ -171,6 +174,9 @@ class MediaRepoShardTestCase(BaseMultiWorkerStreamTestCase):
|
|||
self.assertEqual(channel2.code, 200, channel2.result["body"])
|
||||
self.assertEqual(channel2.result["body"], b"Hello!")
|
||||
|
||||
# We expect only one new file to have been persisted.
|
||||
self.assertEqual(start_count + 1, self._count_remote_media())
|
||||
|
||||
def test_download_image_race(self):
|
||||
"""Test that fetching remote *images* from two different processes at
|
||||
the same time works.
|
||||
|
@ -180,6 +186,8 @@ class MediaRepoShardTestCase(BaseMultiWorkerStreamTestCase):
|
|||
hs1 = self.make_worker_hs("synapse.app.generic_worker")
|
||||
hs2 = self.make_worker_hs("synapse.app.generic_worker")
|
||||
|
||||
start_count = self._count_remote_thumbnails()
|
||||
|
||||
channel1, request1 = self._get_media_req(hs1, "example.com:443", "PIC1")
|
||||
channel2, request2 = self._get_media_req(hs2, "example.com:443", "PIC1")
|
||||
|
||||
|
@ -209,6 +217,25 @@ class MediaRepoShardTestCase(BaseMultiWorkerStreamTestCase):
|
|||
self.assertEqual(channel2.code, 200, channel2.result["body"])
|
||||
self.assertEqual(channel2.result["body"], png_data)
|
||||
|
||||
# We expect only three new thumbnails to have been persisted.
|
||||
self.assertEqual(start_count + 3, self._count_remote_thumbnails())
|
||||
|
||||
def _count_remote_media(self) -> int:
|
||||
"""Count the number of files in our remote media directory.
|
||||
"""
|
||||
path = os.path.join(
|
||||
self.hs.get_media_repository().primary_base_path, "remote_content"
|
||||
)
|
||||
return sum(len(files) for _, _, files in os.walk(path))
|
||||
|
||||
def _count_remote_thumbnails(self) -> int:
|
||||
"""Count the number of files in our remote thumbnails directory.
|
||||
"""
|
||||
path = os.path.join(
|
||||
self.hs.get_media_repository().primary_base_path, "remote_thumbnail"
|
||||
)
|
||||
return sum(len(files) for _, _, files in os.walk(path))
|
||||
|
||||
|
||||
def get_connection_factory():
|
||||
# this needs to happen once, but not until we are ready to run the first test
|
||||
|
|
Loading…
Reference in New Issue