Convert more of the media code to async/await (#7873)

pull/7953/head
Patrick Cloke 2020-07-24 09:39:02 -04:00 committed by GitHub
parent 6a080ea184
commit 5ea29d7f85
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 47 additions and 34 deletions

1
changelog.d/7873.misc Normal file
View File

@ -0,0 +1 @@
Convert more media code to async/await.

View File

@ -18,7 +18,6 @@ import logging
import os import os
import urllib import urllib
from twisted.internet import defer
from twisted.protocols.basic import FileSender from twisted.protocols.basic import FileSender
from synapse.api.errors import Codes, SynapseError, cs_error from synapse.api.errors import Codes, SynapseError, cs_error
@ -77,8 +76,9 @@ def respond_404(request):
) )
@defer.inlineCallbacks async def respond_with_file(
def respond_with_file(request, media_type, file_path, file_size=None, upload_name=None): request, media_type, file_path, file_size=None, upload_name=None
):
logger.debug("Responding with %r", file_path) logger.debug("Responding with %r", file_path)
if os.path.isfile(file_path): if os.path.isfile(file_path):
@ -89,7 +89,7 @@ def respond_with_file(request, media_type, file_path, file_size=None, upload_nam
add_file_headers(request, media_type, file_size, upload_name) add_file_headers(request, media_type, file_size, upload_name)
with open(file_path, "rb") as f: with open(file_path, "rb") as f:
yield make_deferred_yieldable(FileSender().beginFileTransfer(f, request)) await make_deferred_yieldable(FileSender().beginFileTransfer(f, request))
finish_request(request) finish_request(request)
else: else:
@ -198,8 +198,9 @@ def _can_encode_filename_as_token(x):
return True return True
@defer.inlineCallbacks async def respond_with_responder(
def respond_with_responder(request, responder, media_type, file_size, upload_name=None): request, responder, media_type, file_size, upload_name=None
):
"""Responds to the request with given responder. If responder is None then """Responds to the request with given responder. If responder is None then
returns 404. returns 404.
@ -218,7 +219,7 @@ def respond_with_responder(request, responder, media_type, file_size, upload_nam
add_file_headers(request, media_type, file_size, upload_name) add_file_headers(request, media_type, file_size, upload_name)
try: try:
with responder: with responder:
yield responder.write_to_consumer(request) await responder.write_to_consumer(request)
except Exception as e: except Exception as e:
# The majority of the time this will be due to the client having gone # The majority of the time this will be due to the client having gone
# away. Unfortunately, Twisted simply throws a generic exception at us # away. Unfortunately, Twisted simply throws a generic exception at us

View File

@ -14,17 +14,18 @@
# limitations under the License. # limitations under the License.
import contextlib import contextlib
import inspect
import logging import logging
import os import os
import shutil import shutil
from typing import Optional
from twisted.internet import defer
from twisted.protocols.basic import FileSender from twisted.protocols.basic import FileSender
from synapse.logging.context import defer_to_thread, make_deferred_yieldable from synapse.logging.context import defer_to_thread, make_deferred_yieldable
from synapse.util.file_consumer import BackgroundFileConsumer from synapse.util.file_consumer import BackgroundFileConsumer
from ._base import Responder from ._base import FileInfo, Responder
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -46,25 +47,24 @@ class MediaStorage(object):
self.filepaths = filepaths self.filepaths = filepaths
self.storage_providers = storage_providers self.storage_providers = storage_providers
@defer.inlineCallbacks async def store_file(self, source, file_info: FileInfo) -> str:
def store_file(self, source, file_info):
"""Write `source` to the on disk media store, and also any other """Write `source` to the on disk media store, and also any other
configured storage providers configured storage providers
Args: Args:
source: A file like object that should be written source: A file like object that should be written
file_info (FileInfo): Info about the file to store file_info: Info about the file to store
Returns: Returns:
Deferred[str]: the file path written to in the primary media store the file path written to in the primary media store
""" """
with self.store_into_file(file_info) as (f, fname, finish_cb): with self.store_into_file(file_info) as (f, fname, finish_cb):
# Write to the main repository # Write to the main repository
yield defer_to_thread( await defer_to_thread(
self.hs.get_reactor(), _write_file_synchronously, source, f self.hs.get_reactor(), _write_file_synchronously, source, f
) )
yield finish_cb() await finish_cb()
return fname return fname
@ -75,7 +75,7 @@ class MediaStorage(object):
Actually yields a 3-tuple (file, fname, finish_cb), where file is a file Actually yields a 3-tuple (file, fname, finish_cb), where file is a file
like object that can be written to, fname is the absolute path of file like object that can be written to, fname is the absolute path of file
on disk, and finish_cb is a function that returns a Deferred. on disk, and finish_cb is a function that returns an awaitable.
fname can be used to read the contents from after upload, e.g. to fname can be used to read the contents from after upload, e.g. to
generate thumbnails. generate thumbnails.
@ -91,7 +91,7 @@ class MediaStorage(object):
with media_storage.store_into_file(info) as (f, fname, finish_cb): with media_storage.store_into_file(info) as (f, fname, finish_cb):
# .. write into f ... # .. write into f ...
yield finish_cb() await finish_cb()
""" """
path = self._file_info_to_path(file_info) path = self._file_info_to_path(file_info)
@ -103,10 +103,13 @@ class MediaStorage(object):
finished_called = [False] finished_called = [False]
@defer.inlineCallbacks async def finish():
def finish():
for provider in self.storage_providers: for provider in self.storage_providers:
yield provider.store_file(path, file_info) # store_file is supposed to return an Awaitable, but guard
# against improper implementations.
result = provider.store_file(path, file_info)
if inspect.isawaitable(result):
await result
finished_called[0] = True finished_called[0] = True
@ -123,17 +126,15 @@ class MediaStorage(object):
if not finished_called: if not finished_called:
raise Exception("Finished callback not called") raise Exception("Finished callback not called")
@defer.inlineCallbacks async def fetch_media(self, file_info: FileInfo) -> Optional[Responder]:
def fetch_media(self, file_info):
"""Attempts to fetch media described by file_info from the local cache """Attempts to fetch media described by file_info from the local cache
and configured storage providers. and configured storage providers.
Args: Args:
file_info (FileInfo) file_info
Returns: Returns:
Deferred[Responder|None]: Returns a Responder if the file was found, Returns a Responder if the file was found, otherwise None.
otherwise None.
""" """
path = self._file_info_to_path(file_info) path = self._file_info_to_path(file_info)
@ -142,23 +143,26 @@ class MediaStorage(object):
return FileResponder(open(local_path, "rb")) return FileResponder(open(local_path, "rb"))
for provider in self.storage_providers: for provider in self.storage_providers:
res = yield provider.fetch(path, file_info) res = provider.fetch(path, file_info)
# Fetch is supposed to return an Awaitable, but guard against
# improper implementations.
if inspect.isawaitable(res):
res = await res
if res: if res:
logger.debug("Streaming %s from %s", path, provider) logger.debug("Streaming %s from %s", path, provider)
return res return res
return None return None
@defer.inlineCallbacks async def ensure_media_is_in_local_cache(self, file_info: FileInfo) -> str:
def ensure_media_is_in_local_cache(self, file_info):
"""Ensures that the given file is in the local cache. Attempts to """Ensures that the given file is in the local cache. Attempts to
download it from storage providers if it isn't. download it from storage providers if it isn't.
Args: Args:
file_info (FileInfo) file_info
Returns: Returns:
Deferred[str]: Full path to local file Full path to local file
""" """
path = self._file_info_to_path(file_info) path = self._file_info_to_path(file_info)
local_path = os.path.join(self.local_media_directory, path) local_path = os.path.join(self.local_media_directory, path)
@ -170,14 +174,18 @@ class MediaStorage(object):
os.makedirs(dirname) os.makedirs(dirname)
for provider in self.storage_providers: for provider in self.storage_providers:
res = yield provider.fetch(path, file_info) res = provider.fetch(path, file_info)
# Fetch is supposed to return an Awaitable, but guard against
# improper implementations.
if inspect.isawaitable(res):
res = await res
if res: if res:
with res: with res:
consumer = BackgroundFileConsumer( consumer = BackgroundFileConsumer(
open(local_path, "wb"), self.hs.get_reactor() open(local_path, "wb"), self.hs.get_reactor()
) )
yield res.write_to_consumer(consumer) await res.write_to_consumer(consumer)
yield consumer.wait() await consumer.wait()
return local_path return local_path
raise Exception("file could not be found") raise Exception("file could not be found")

View File

@ -26,6 +26,7 @@ import attr
from parameterized import parameterized_class from parameterized import parameterized_class
from PIL import Image as Image from PIL import Image as Image
from twisted.internet import defer
from twisted.internet.defer import Deferred from twisted.internet.defer import Deferred
from synapse.logging.context import make_deferred_yieldable from synapse.logging.context import make_deferred_yieldable
@ -77,7 +78,9 @@ class MediaStorageTests(unittest.HomeserverTestCase):
# This uses a real blocking threadpool so we have to wait for it to be # This uses a real blocking threadpool so we have to wait for it to be
# actually done :/ # actually done :/
x = self.media_storage.ensure_media_is_in_local_cache(file_info) x = defer.ensureDeferred(
self.media_storage.ensure_media_is_in_local_cache(file_info)
)
# Hotloop until the threadpool does its job... # Hotloop until the threadpool does its job...
self.wait_on_thread(x) self.wait_on_thread(x)