Convert the remaining media repo code to async / await. (#7947)
							parent
							
								
									8553f46498
								
							
						
					
					
						commit
						68626ff8e9
					
				|  | @ -0,0 +1 @@ | |||
| Convert various parts of the codebase to async/await. | ||||
|  | @ -17,7 +17,9 @@ | |||
| import logging | ||||
| import os | ||||
| import urllib | ||||
| from typing import Awaitable | ||||
| 
 | ||||
| from twisted.internet.interfaces import IConsumer | ||||
| from twisted.protocols.basic import FileSender | ||||
| 
 | ||||
| from synapse.api.errors import Codes, SynapseError, cs_error | ||||
|  | @ -240,14 +242,14 @@ class Responder(object): | |||
|     held can be cleaned up. | ||||
|     """ | ||||
| 
 | ||||
|     def write_to_consumer(self, consumer): | ||||
|     def write_to_consumer(self, consumer: IConsumer) -> Awaitable: | ||||
|         """Stream response into consumer | ||||
| 
 | ||||
|         Args: | ||||
|             consumer (IConsumer) | ||||
|             consumer: The consumer to stream into. | ||||
| 
 | ||||
|         Returns: | ||||
|             Deferred: Resolves once the response has finished being written | ||||
|             Resolves once the response has finished being written | ||||
|         """ | ||||
|         pass | ||||
| 
 | ||||
|  |  | |||
|  | @ -18,10 +18,11 @@ import errno | |||
| import logging | ||||
| import os | ||||
| import shutil | ||||
| from typing import Dict, Tuple | ||||
| from typing import IO, Dict, Optional, Tuple | ||||
| 
 | ||||
| import twisted.internet.error | ||||
| import twisted.web.http | ||||
| from twisted.web.http import Request | ||||
| from twisted.web.resource import Resource | ||||
| 
 | ||||
| from synapse.api.errors import ( | ||||
|  | @ -40,6 +41,7 @@ from synapse.util.stringutils import random_string | |||
| 
 | ||||
| from ._base import ( | ||||
|     FileInfo, | ||||
|     Responder, | ||||
|     get_filename_from_headers, | ||||
|     respond_404, | ||||
|     respond_with_responder, | ||||
|  | @ -135,19 +137,24 @@ class MediaRepository(object): | |||
|             self.recently_accessed_locals.add(media_id) | ||||
| 
 | ||||
|     async def create_content( | ||||
|         self, media_type, upload_name, content, content_length, auth_user | ||||
|     ): | ||||
|         self, | ||||
|         media_type: str, | ||||
|         upload_name: str, | ||||
|         content: IO, | ||||
|         content_length: int, | ||||
|         auth_user: str, | ||||
|     ) -> str: | ||||
|         """Store uploaded content for a local user and return the mxc URL | ||||
| 
 | ||||
|         Args: | ||||
|             media_type(str): The content type of the file | ||||
|             upload_name(str): The name of the file | ||||
|             media_type: The content type of the file | ||||
|             upload_name: The name of the file | ||||
|             content: A file like object that is the content to store | ||||
|             content_length(int): The length of the content | ||||
|             auth_user(str): The user_id of the uploader | ||||
|             content_length: The length of the content | ||||
|             auth_user: The user_id of the uploader | ||||
| 
 | ||||
|         Returns: | ||||
|             Deferred[str]: The mxc url of the stored content | ||||
|             The mxc url of the stored content | ||||
|         """ | ||||
|         media_id = random_string(24) | ||||
| 
 | ||||
|  | @ -170,19 +177,20 @@ class MediaRepository(object): | |||
| 
 | ||||
|         return "mxc://%s/%s" % (self.server_name, media_id) | ||||
| 
 | ||||
|     async def get_local_media(self, request, media_id, name): | ||||
|     async def get_local_media( | ||||
|         self, request: Request, media_id: str, name: Optional[str] | ||||
|     ) -> None: | ||||
|         """Responds to reqests for local media, if exists, or returns 404. | ||||
| 
 | ||||
|         Args: | ||||
|             request(twisted.web.http.Request) | ||||
|             media_id (str): The media ID of the content. (This is the same as | ||||
|             request: The incoming request. | ||||
|             media_id: The media ID of the content. (This is the same as | ||||
|                 the file_id for local content.) | ||||
|             name (str|None): Optional name that, if specified, will be used as | ||||
|             name: Optional name that, if specified, will be used as | ||||
|                 the filename in the Content-Disposition header of the response. | ||||
| 
 | ||||
|         Returns: | ||||
|             Deferred: Resolves once a response has successfully been written | ||||
|                 to request | ||||
|             Resolves once a response has successfully been written to request | ||||
|         """ | ||||
|         media_info = await self.store.get_local_media(media_id) | ||||
|         if not media_info or media_info["quarantined_by"]: | ||||
|  | @ -203,20 +211,20 @@ class MediaRepository(object): | |||
|             request, responder, media_type, media_length, upload_name | ||||
|         ) | ||||
| 
 | ||||
|     async def get_remote_media(self, request, server_name, media_id, name): | ||||
|     async def get_remote_media( | ||||
|         self, request: Request, server_name: str, media_id: str, name: Optional[str] | ||||
|     ) -> None: | ||||
|         """Respond to requests for remote media. | ||||
| 
 | ||||
|         Args: | ||||
|             request(twisted.web.http.Request) | ||||
|             server_name (str): Remote server_name where the media originated. | ||||
|             media_id (str): The media ID of the content (as defined by the | ||||
|                 remote server). | ||||
|             name (str|None): Optional name that, if specified, will be used as | ||||
|             request: The incoming request. | ||||
|             server_name: Remote server_name where the media originated. | ||||
|             media_id: The media ID of the content (as defined by the remote server). | ||||
|             name: Optional name that, if specified, will be used as | ||||
|                 the filename in the Content-Disposition header of the response. | ||||
| 
 | ||||
|         Returns: | ||||
|             Deferred: Resolves once a response has successfully been written | ||||
|                 to request | ||||
|             Resolves once a response has successfully been written to request | ||||
|         """ | ||||
|         if ( | ||||
|             self.federation_domain_whitelist is not None | ||||
|  | @ -245,17 +253,16 @@ class MediaRepository(object): | |||
|         else: | ||||
|             respond_404(request) | ||||
| 
 | ||||
|     async def get_remote_media_info(self, server_name, media_id): | ||||
|     async def get_remote_media_info(self, server_name: str, media_id: str) -> dict: | ||||
|         """Gets the media info associated with the remote file, downloading | ||||
|         if necessary. | ||||
| 
 | ||||
|         Args: | ||||
|             server_name (str): Remote server_name where the media originated. | ||||
|             media_id (str): The media ID of the content (as defined by the | ||||
|                 remote server). | ||||
|             server_name: Remote server_name where the media originated. | ||||
|             media_id: The media ID of the content (as defined by the remote server). | ||||
| 
 | ||||
|         Returns: | ||||
|             Deferred[dict]: The media_info of the file | ||||
|             The media info of the file | ||||
|         """ | ||||
|         if ( | ||||
|             self.federation_domain_whitelist is not None | ||||
|  | @ -278,7 +285,9 @@ class MediaRepository(object): | |||
| 
 | ||||
|         return media_info | ||||
| 
 | ||||
|     async def _get_remote_media_impl(self, server_name, media_id): | ||||
|     async def _get_remote_media_impl( | ||||
|         self, server_name: str, media_id: str | ||||
|     ) -> Tuple[Optional[Responder], dict]: | ||||
|         """Looks for media in local cache, if not there then attempt to | ||||
|         download from remote server. | ||||
| 
 | ||||
|  | @ -288,7 +297,7 @@ class MediaRepository(object): | |||
|                 remote server). | ||||
| 
 | ||||
|         Returns: | ||||
|             Deferred[(Responder, media_info)] | ||||
|             A tuple of responder and the media info of the file. | ||||
|         """ | ||||
|         media_info = await self.store.get_cached_remote_media(server_name, media_id) | ||||
| 
 | ||||
|  | @ -319,19 +328,21 @@ class MediaRepository(object): | |||
|         responder = await self.media_storage.fetch_media(file_info) | ||||
|         return responder, media_info | ||||
| 
 | ||||
|     async def _download_remote_file(self, server_name, media_id, file_id): | ||||
|     async def _download_remote_file( | ||||
|         self, server_name: str, media_id: str, file_id: str | ||||
|     ) -> dict: | ||||
|         """Attempt to download the remote file from the given server name, | ||||
|         using the given file_id as the local id. | ||||
| 
 | ||||
|         Args: | ||||
|             server_name (str): Originating server | ||||
|             media_id (str): The media ID of the content (as defined by the | ||||
|             server_name: Originating server | ||||
|             media_id: The media ID of the content (as defined by the | ||||
|                 remote server). This is different than the file_id, which is | ||||
|                 locally generated. | ||||
|             file_id (str): Local file ID | ||||
|             file_id: Local file ID | ||||
| 
 | ||||
|         Returns: | ||||
|             Deferred[MediaInfo] | ||||
|             The media info of the file. | ||||
|         """ | ||||
| 
 | ||||
|         file_info = FileInfo(server_name=server_name, file_id=file_id) | ||||
|  | @ -549,25 +560,31 @@ class MediaRepository(object): | |||
|             return output_path | ||||
| 
 | ||||
|     async def _generate_thumbnails( | ||||
|         self, server_name, media_id, file_id, media_type, url_cache=False | ||||
|     ): | ||||
|         self, | ||||
|         server_name: Optional[str], | ||||
|         media_id: str, | ||||
|         file_id: str, | ||||
|         media_type: str, | ||||
|         url_cache: bool = False, | ||||
|     ) -> Optional[dict]: | ||||
|         """Generate and store thumbnails for an image. | ||||
| 
 | ||||
|         Args: | ||||
|             server_name (str|None): The server name if remote media, else None if local | ||||
|             media_id (str): The media ID of the content. (This is the same as | ||||
|             server_name: The server name if remote media, else None if local | ||||
|             media_id: The media ID of the content. (This is the same as | ||||
|                 the file_id for local content) | ||||
|             file_id (str): Local file ID | ||||
|             media_type (str): The content type of the file | ||||
|             url_cache (bool): If we are thumbnailing images downloaded for the URL cache, | ||||
|             file_id: Local file ID | ||||
|             media_type: The content type of the file | ||||
|             url_cache: If we are thumbnailing images downloaded for the URL cache, | ||||
|                 used exclusively by the url previewer | ||||
| 
 | ||||
|         Returns: | ||||
|             Deferred[dict]: Dict with "width" and "height" keys of original image | ||||
|             Dict with "width" and "height" keys of original image or None if the | ||||
|             media cannot be thumbnailed. | ||||
|         """ | ||||
|         requirements = self._get_thumbnail_requirements(media_type) | ||||
|         if not requirements: | ||||
|             return | ||||
|             return None | ||||
| 
 | ||||
|         input_path = await self.media_storage.ensure_media_is_in_local_cache( | ||||
|             FileInfo(server_name, file_id, url_cache=url_cache) | ||||
|  | @ -584,7 +601,7 @@ class MediaRepository(object): | |||
|                 m_height, | ||||
|                 self.max_image_pixels, | ||||
|             ) | ||||
|             return | ||||
|             return None | ||||
| 
 | ||||
|         if thumbnailer.transpose_method is not None: | ||||
|             m_width, m_height = await defer_to_thread( | ||||
|  |  | |||
|  | @ -12,13 +12,12 @@ | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| # See the License for the specific language governing permissions and | ||||
| # limitations under the License. | ||||
| 
 | ||||
| import contextlib | ||||
| import inspect | ||||
| import logging | ||||
| import os | ||||
| import shutil | ||||
| from typing import Optional | ||||
| from typing import IO, TYPE_CHECKING, Any, Optional, Sequence | ||||
| 
 | ||||
| from twisted.protocols.basic import FileSender | ||||
| 
 | ||||
|  | @ -26,6 +25,12 @@ from synapse.logging.context import defer_to_thread, make_deferred_yieldable | |||
| from synapse.util.file_consumer import BackgroundFileConsumer | ||||
| 
 | ||||
| from ._base import FileInfo, Responder | ||||
| from .filepath import MediaFilePaths | ||||
| 
 | ||||
| if TYPE_CHECKING: | ||||
|     from synapse.server import HomeServer | ||||
| 
 | ||||
|     from .storage_provider import StorageProvider | ||||
| 
 | ||||
| logger = logging.getLogger(__name__) | ||||
| 
 | ||||
|  | @ -34,20 +39,25 @@ class MediaStorage(object): | |||
|     """Responsible for storing/fetching files from local sources. | ||||
| 
 | ||||
|     Args: | ||||
|         hs (synapse.server.Homeserver) | ||||
|         local_media_directory (str): Base path where we store media on disk | ||||
|         filepaths (MediaFilePaths) | ||||
|         storage_providers ([StorageProvider]): List of StorageProvider that are | ||||
|             used to fetch and store files. | ||||
|         hs | ||||
|         local_media_directory: Base path where we store media on disk | ||||
|         filepaths | ||||
|         storage_providers: List of StorageProvider that are used to fetch and store files. | ||||
|     """ | ||||
| 
 | ||||
|     def __init__(self, hs, local_media_directory, filepaths, storage_providers): | ||||
|     def __init__( | ||||
|         self, | ||||
|         hs: "HomeServer", | ||||
|         local_media_directory: str, | ||||
|         filepaths: MediaFilePaths, | ||||
|         storage_providers: Sequence["StorageProvider"], | ||||
|     ): | ||||
|         self.hs = hs | ||||
|         self.local_media_directory = local_media_directory | ||||
|         self.filepaths = filepaths | ||||
|         self.storage_providers = storage_providers | ||||
| 
 | ||||
|     async def store_file(self, source, file_info: FileInfo) -> str: | ||||
|     async def store_file(self, source: IO, file_info: FileInfo) -> str: | ||||
|         """Write `source` to the on disk media store, and also any other | ||||
|         configured storage providers | ||||
| 
 | ||||
|  | @ -69,7 +79,7 @@ class MediaStorage(object): | |||
|         return fname | ||||
| 
 | ||||
|     @contextlib.contextmanager | ||||
|     def store_into_file(self, file_info): | ||||
|     def store_into_file(self, file_info: FileInfo): | ||||
|         """Context manager used to get a file like object to write into, as | ||||
|         described by file_info. | ||||
| 
 | ||||
|  | @ -85,7 +95,7 @@ class MediaStorage(object): | |||
|         error. | ||||
| 
 | ||||
|         Args: | ||||
|             file_info (FileInfo): Info about the file to store | ||||
|             file_info: Info about the file to store | ||||
| 
 | ||||
|         Example: | ||||
| 
 | ||||
|  | @ -143,9 +153,9 @@ class MediaStorage(object): | |||
|             return FileResponder(open(local_path, "rb")) | ||||
| 
 | ||||
|         for provider in self.storage_providers: | ||||
|             res = provider.fetch(path, file_info) | ||||
|             # Fetch is supposed to return an Awaitable, but guard against | ||||
|             # improper implementations. | ||||
|             res = provider.fetch(path, file_info)  # type: Any | ||||
|             # Fetch is supposed to return an Awaitable[Responder], but guard | ||||
|             # against improper implementations. | ||||
|             if inspect.isawaitable(res): | ||||
|                 res = await res | ||||
|             if res: | ||||
|  | @ -174,9 +184,9 @@ class MediaStorage(object): | |||
|             os.makedirs(dirname) | ||||
| 
 | ||||
|         for provider in self.storage_providers: | ||||
|             res = provider.fetch(path, file_info) | ||||
|             # Fetch is supposed to return an Awaitable, but guard against | ||||
|             # improper implementations. | ||||
|             res = provider.fetch(path, file_info)  # type: Any | ||||
|             # Fetch is supposed to return an Awaitable[Responder], but guard | ||||
|             # against improper implementations. | ||||
|             if inspect.isawaitable(res): | ||||
|                 res = await res | ||||
|             if res: | ||||
|  | @ -190,17 +200,11 @@ class MediaStorage(object): | |||
| 
 | ||||
|         raise Exception("file could not be found") | ||||
| 
 | ||||
|     def _file_info_to_path(self, file_info): | ||||
|     def _file_info_to_path(self, file_info: FileInfo) -> str: | ||||
|         """Converts file_info into a relative path. | ||||
| 
 | ||||
|         The path is suitable for storing files under a directory, e.g. used to | ||||
|         store files on local FS under the base media repository directory. | ||||
| 
 | ||||
|         Args: | ||||
|             file_info (FileInfo) | ||||
| 
 | ||||
|         Returns: | ||||
|             str | ||||
|         """ | ||||
|         if file_info.url_cache: | ||||
|             if file_info.thumbnail: | ||||
|  |  | |||
|  | @ -231,16 +231,16 @@ class PreviewUrlResource(DirectServeJsonResource): | |||
|         og = await make_deferred_yieldable(defer.maybeDeferred(observable.observe)) | ||||
|         respond_with_json_bytes(request, 200, og, send_cors=True) | ||||
| 
 | ||||
|     async def _do_preview(self, url, user, ts): | ||||
|     async def _do_preview(self, url: str, user: str, ts: int) -> bytes: | ||||
|         """Check the db, and download the URL and build a preview | ||||
| 
 | ||||
|         Args: | ||||
|             url (str): | ||||
|             user (str): | ||||
|             ts (int): | ||||
|             url: The URL to preview. | ||||
|             user: The user requesting the preview. | ||||
|             ts: The timestamp requested for the preview. | ||||
| 
 | ||||
|         Returns: | ||||
|             Deferred[bytes]: json-encoded og data | ||||
|             json-encoded og data | ||||
|         """ | ||||
|         # check the URL cache in the DB (which will also provide us with | ||||
|         # historical previews, if we have any) | ||||
|  |  | |||
|  | @ -16,62 +16,62 @@ | |||
| import logging | ||||
| import os | ||||
| import shutil | ||||
| 
 | ||||
| from twisted.internet import defer | ||||
| from typing import Optional | ||||
| 
 | ||||
| from synapse.config._base import Config | ||||
| from synapse.logging.context import defer_to_thread, run_in_background | ||||
| 
 | ||||
| from ._base import FileInfo, Responder | ||||
| from .media_storage import FileResponder | ||||
| 
 | ||||
| logger = logging.getLogger(__name__) | ||||
| 
 | ||||
| 
 | ||||
| class StorageProvider(object): | ||||
| class StorageProvider: | ||||
|     """A storage provider is a service that can store uploaded media and | ||||
|     retrieve them. | ||||
|     """ | ||||
| 
 | ||||
|     def store_file(self, path, file_info): | ||||
|     async def store_file(self, path: str, file_info: FileInfo): | ||||
|         """Store the file described by file_info. The actual contents can be | ||||
|         retrieved by reading the file in file_info.upload_path. | ||||
| 
 | ||||
|         Args: | ||||
|             path (str): Relative path of file in local cache | ||||
|             file_info (FileInfo) | ||||
| 
 | ||||
|         Returns: | ||||
|             Deferred | ||||
|             path: Relative path of file in local cache | ||||
|             file_info: The metadata of the file. | ||||
|         """ | ||||
|         pass | ||||
| 
 | ||||
|     def fetch(self, path, file_info): | ||||
|     async def fetch(self, path: str, file_info: FileInfo) -> Optional[Responder]: | ||||
|         """Attempt to fetch the file described by file_info and stream it | ||||
|         into writer. | ||||
| 
 | ||||
|         Args: | ||||
|             path (str): Relative path of file in local cache | ||||
|             file_info (FileInfo) | ||||
|             path: Relative path of file in local cache | ||||
|             file_info: The metadata of the file. | ||||
| 
 | ||||
|         Returns: | ||||
|             Deferred(Responder): Returns a Responder if the provider has the file, | ||||
|                 otherwise returns None. | ||||
|             Returns a Responder if the provider has the file, otherwise returns None. | ||||
|         """ | ||||
|         pass | ||||
| 
 | ||||
| 
 | ||||
| class StorageProviderWrapper(StorageProvider): | ||||
|     """Wraps a storage provider and provides various config options | ||||
| 
 | ||||
|     Args: | ||||
|         backend (StorageProvider) | ||||
|         store_local (bool): Whether to store new local files or not. | ||||
|         store_synchronous (bool): Whether to wait for file to be successfully | ||||
|         backend: The storage provider to wrap. | ||||
|         store_local: Whether to store new local files or not. | ||||
|         store_synchronous: Whether to wait for file to be successfully | ||||
|             uploaded, or todo the upload in the background. | ||||
|         store_remote (bool): Whether remote media should be uploaded | ||||
|         store_remote: Whether remote media should be uploaded | ||||
|     """ | ||||
| 
 | ||||
|     def __init__(self, backend, store_local, store_synchronous, store_remote): | ||||
|     def __init__( | ||||
|         self, | ||||
|         backend: StorageProvider, | ||||
|         store_local: bool, | ||||
|         store_synchronous: bool, | ||||
|         store_remote: bool, | ||||
|     ): | ||||
|         self.backend = backend | ||||
|         self.store_local = store_local | ||||
|         self.store_synchronous = store_synchronous | ||||
|  | @ -80,15 +80,15 @@ class StorageProviderWrapper(StorageProvider): | |||
|     def __str__(self): | ||||
|         return "StorageProviderWrapper[%s]" % (self.backend,) | ||||
| 
 | ||||
|     def store_file(self, path, file_info): | ||||
|     async def store_file(self, path, file_info): | ||||
|         if not file_info.server_name and not self.store_local: | ||||
|             return defer.succeed(None) | ||||
|             return None | ||||
| 
 | ||||
|         if file_info.server_name and not self.store_remote: | ||||
|             return defer.succeed(None) | ||||
|             return None | ||||
| 
 | ||||
|         if self.store_synchronous: | ||||
|             return self.backend.store_file(path, file_info) | ||||
|             return await self.backend.store_file(path, file_info) | ||||
|         else: | ||||
|             # TODO: Handle errors. | ||||
|             def store(): | ||||
|  | @ -98,10 +98,10 @@ class StorageProviderWrapper(StorageProvider): | |||
|                     logger.exception("Error storing file") | ||||
| 
 | ||||
|             run_in_background(store) | ||||
|             return defer.succeed(None) | ||||
|             return None | ||||
| 
 | ||||
|     def fetch(self, path, file_info): | ||||
|         return self.backend.fetch(path, file_info) | ||||
|     async def fetch(self, path, file_info): | ||||
|         return await self.backend.fetch(path, file_info) | ||||
| 
 | ||||
| 
 | ||||
| class FileStorageProviderBackend(StorageProvider): | ||||
|  | @ -120,7 +120,7 @@ class FileStorageProviderBackend(StorageProvider): | |||
|     def __str__(self): | ||||
|         return "FileStorageProviderBackend[%s]" % (self.base_directory,) | ||||
| 
 | ||||
|     def store_file(self, path, file_info): | ||||
|     async def store_file(self, path, file_info): | ||||
|         """See StorageProvider.store_file""" | ||||
| 
 | ||||
|         primary_fname = os.path.join(self.cache_directory, path) | ||||
|  | @ -130,11 +130,11 @@ class FileStorageProviderBackend(StorageProvider): | |||
|         if not os.path.exists(dirname): | ||||
|             os.makedirs(dirname) | ||||
| 
 | ||||
|         return defer_to_thread( | ||||
|         return await defer_to_thread( | ||||
|             self.hs.get_reactor(), shutil.copyfile, primary_fname, backup_fname | ||||
|         ) | ||||
| 
 | ||||
|     def fetch(self, path, file_info): | ||||
|     async def fetch(self, path, file_info): | ||||
|         """See StorageProvider.fetch""" | ||||
| 
 | ||||
|         backup_fname = os.path.join(self.base_directory, path) | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue
	
	 Patrick Cloke
						Patrick Cloke