Convert more of the media code to async/await (#7873)

pull/7953/head
Patrick Cloke 2020-07-24 09:39:02 -04:00 committed by GitHub
parent 6a080ea184
commit 5ea29d7f85
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 47 additions and 34 deletions

1
changelog.d/7873.misc Normal file
View File

@ -0,0 +1 @@
Convert more media code to async/await.

View File

@ -18,7 +18,6 @@ import logging
import os
import urllib
from twisted.internet import defer
from twisted.protocols.basic import FileSender
from synapse.api.errors import Codes, SynapseError, cs_error
@ -77,8 +76,9 @@ def respond_404(request):
)
@defer.inlineCallbacks
def respond_with_file(request, media_type, file_path, file_size=None, upload_name=None):
async def respond_with_file(
request, media_type, file_path, file_size=None, upload_name=None
):
logger.debug("Responding with %r", file_path)
if os.path.isfile(file_path):
@ -89,7 +89,7 @@ def respond_with_file(request, media_type, file_path, file_size=None, upload_nam
add_file_headers(request, media_type, file_size, upload_name)
with open(file_path, "rb") as f:
yield make_deferred_yieldable(FileSender().beginFileTransfer(f, request))
await make_deferred_yieldable(FileSender().beginFileTransfer(f, request))
finish_request(request)
else:
@ -198,8 +198,9 @@ def _can_encode_filename_as_token(x):
return True
@defer.inlineCallbacks
def respond_with_responder(request, responder, media_type, file_size, upload_name=None):
async def respond_with_responder(
request, responder, media_type, file_size, upload_name=None
):
"""Responds to the request with given responder. If responder is None then
returns 404.
@ -218,7 +219,7 @@ def respond_with_responder(request, responder, media_type, file_size, upload_nam
add_file_headers(request, media_type, file_size, upload_name)
try:
with responder:
yield responder.write_to_consumer(request)
await responder.write_to_consumer(request)
except Exception as e:
# The majority of the time this will be due to the client having gone
# away. Unfortunately, Twisted simply throws a generic exception at us

View File

@ -14,17 +14,18 @@
# limitations under the License.
import contextlib
import inspect
import logging
import os
import shutil
from typing import Optional
from twisted.internet import defer
from twisted.protocols.basic import FileSender
from synapse.logging.context import defer_to_thread, make_deferred_yieldable
from synapse.util.file_consumer import BackgroundFileConsumer
from ._base import Responder
from ._base import FileInfo, Responder
logger = logging.getLogger(__name__)
@ -46,25 +47,24 @@ class MediaStorage(object):
self.filepaths = filepaths
self.storage_providers = storage_providers
@defer.inlineCallbacks
def store_file(self, source, file_info):
async def store_file(self, source, file_info: FileInfo) -> str:
"""Write `source` to the on disk media store, and also any other
configured storage providers
Args:
source: A file like object that should be written
file_info (FileInfo): Info about the file to store
file_info: Info about the file to store
Returns:
Deferred[str]: the file path written to in the primary media store
the file path written to in the primary media store
"""
with self.store_into_file(file_info) as (f, fname, finish_cb):
# Write to the main repository
yield defer_to_thread(
await defer_to_thread(
self.hs.get_reactor(), _write_file_synchronously, source, f
)
yield finish_cb()
await finish_cb()
return fname
@ -75,7 +75,7 @@ class MediaStorage(object):
Actually yields a 3-tuple (file, fname, finish_cb), where file is a file
like object that can be written to, fname is the absolute path of file
on disk, and finish_cb is a function that returns a Deferred.
on disk, and finish_cb is a function that returns an awaitable.
fname can be used to read the contents from after upload, e.g. to
generate thumbnails.
@ -91,7 +91,7 @@ class MediaStorage(object):
with media_storage.store_into_file(info) as (f, fname, finish_cb):
# .. write into f ...
yield finish_cb()
await finish_cb()
"""
path = self._file_info_to_path(file_info)
@ -103,10 +103,13 @@ class MediaStorage(object):
finished_called = [False]
@defer.inlineCallbacks
def finish():
async def finish():
for provider in self.storage_providers:
yield provider.store_file(path, file_info)
# store_file is supposed to return an Awaitable, but guard
# against improper implementations.
result = provider.store_file(path, file_info)
if inspect.isawaitable(result):
await result
finished_called[0] = True
@ -123,17 +126,15 @@ class MediaStorage(object):
if not finished_called:
raise Exception("Finished callback not called")
@defer.inlineCallbacks
def fetch_media(self, file_info):
async def fetch_media(self, file_info: FileInfo) -> Optional[Responder]:
"""Attempts to fetch media described by file_info from the local cache
and configured storage providers.
Args:
file_info (FileInfo)
file_info
Returns:
Deferred[Responder|None]: Returns a Responder if the file was found,
otherwise None.
Returns a Responder if the file was found, otherwise None.
"""
path = self._file_info_to_path(file_info)
@ -142,23 +143,26 @@ class MediaStorage(object):
return FileResponder(open(local_path, "rb"))
for provider in self.storage_providers:
res = yield provider.fetch(path, file_info)
res = provider.fetch(path, file_info)
# Fetch is supposed to return an Awaitable, but guard against
# improper implementations.
if inspect.isawaitable(res):
res = await res
if res:
logger.debug("Streaming %s from %s", path, provider)
return res
return None
@defer.inlineCallbacks
def ensure_media_is_in_local_cache(self, file_info):
async def ensure_media_is_in_local_cache(self, file_info: FileInfo) -> str:
"""Ensures that the given file is in the local cache. Attempts to
download it from storage providers if it isn't.
Args:
file_info (FileInfo)
file_info
Returns:
Deferred[str]: Full path to local file
Full path to local file
"""
path = self._file_info_to_path(file_info)
local_path = os.path.join(self.local_media_directory, path)
@ -170,14 +174,18 @@ class MediaStorage(object):
os.makedirs(dirname)
for provider in self.storage_providers:
res = yield provider.fetch(path, file_info)
res = provider.fetch(path, file_info)
# Fetch is supposed to return an Awaitable, but guard against
# improper implementations.
if inspect.isawaitable(res):
res = await res
if res:
with res:
consumer = BackgroundFileConsumer(
open(local_path, "wb"), self.hs.get_reactor()
)
yield res.write_to_consumer(consumer)
yield consumer.wait()
await res.write_to_consumer(consumer)
await consumer.wait()
return local_path
raise Exception("file could not be found")

View File

@ -26,6 +26,7 @@ import attr
from parameterized import parameterized_class
from PIL import Image as Image
from twisted.internet import defer
from twisted.internet.defer import Deferred
from synapse.logging.context import make_deferred_yieldable
@ -77,7 +78,9 @@ class MediaStorageTests(unittest.HomeserverTestCase):
# This uses a real blocking threadpool so we have to wait for it to be
# actually done :/
x = self.media_storage.ensure_media_is_in_local_cache(file_info)
x = defer.ensureDeferred(
self.media_storage.ensure_media_is_in_local_cache(file_info)
)
# Hotloop until the threadpool does its job...
self.wait_on_thread(x)