From 9e4610cc272fc8e5db39608de83ce48360889e42 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Tue, 8 Jun 2021 08:30:48 -0400 Subject: [PATCH] Correct type hints for parse_string(s)_from_args. (#10137) --- changelog.d/10137.misc | 1 + mypy.ini | 1 + synapse/http/servlet.py | 179 ++++++++++++++--------- synapse/rest/admin/rooms.py | 2 +- synapse/rest/client/v1/login.py | 8 +- synapse/rest/client/v1/room.py | 4 +- synapse/rest/consent/consent_resource.py | 9 +- synapse/rest/media/v1/upload_resource.py | 11 +- 8 files changed, 132 insertions(+), 83 deletions(-) create mode 100644 changelog.d/10137.misc diff --git a/changelog.d/10137.misc b/changelog.d/10137.misc new file mode 100644 index 0000000000..a901f8431e --- /dev/null +++ b/changelog.d/10137.misc @@ -0,0 +1 @@ +Add `parse_strings_from_args` for parsing an array from query parameters. diff --git a/mypy.ini b/mypy.ini index 8ba1b96311..1ab9001831 100644 --- a/mypy.ini +++ b/mypy.ini @@ -32,6 +32,7 @@ files = synapse/http/federation/matrix_federation_agent.py, synapse/http/federation/well_known_resolver.py, synapse/http/matrixfederationclient.py, + synapse/http/servlet.py, synapse/http/server.py, synapse/http/site.py, synapse/logging, diff --git a/synapse/http/servlet.py b/synapse/http/servlet.py index 3f4f2411fc..d61563d39b 100644 --- a/synapse/http/servlet.py +++ b/synapse/http/servlet.py @@ -15,10 +15,12 @@ """ This module contains base REST classes for constructing REST servlets. """ import logging -from typing import Iterable, List, Optional, Union, overload +from typing import Dict, Iterable, List, Optional, overload from typing_extensions import Literal +from twisted.web.server import Request + from synapse.api.errors import Codes, SynapseError from synapse.util import json_decoder @@ -108,13 +110,66 @@ def parse_boolean_from_args(args, name, default=None, required=False): return default +@overload +def parse_bytes_from_args( + args: Dict[bytes, List[bytes]], + name: str, + default: Literal[None] = None, + required: Literal[True] = True, +) -> bytes: + ... + + +@overload +def parse_bytes_from_args( + args: Dict[bytes, List[bytes]], + name: str, + default: Optional[bytes] = None, + required: bool = False, +) -> Optional[bytes]: + ... + + +def parse_bytes_from_args( + args: Dict[bytes, List[bytes]], + name: str, + default: Optional[bytes] = None, + required: bool = False, +) -> Optional[bytes]: + """ + Parse a string parameter as bytes from the request query string. + + Args: + args: A mapping of request args as bytes to a list of bytes (e.g. request.args). + name: the name of the query parameter. + default: value to use if the parameter is absent, + defaults to None. Must be bytes if encoding is None. + required: whether to raise a 400 SynapseError if the + parameter is absent, defaults to False. + Returns: + Bytes or the default value. + + Raises: + SynapseError if the parameter is absent and required. + """ + name_bytes = name.encode("ascii") + + if name_bytes in args: + return args[name_bytes][0] + elif required: + message = "Missing string query parameter %s" % (name,) + raise SynapseError(400, message, errcode=Codes.MISSING_PARAM) + + return default + + def parse_string( - request, - name: Union[bytes, str], + request: Request, + name: str, default: Optional[str] = None, required: bool = False, allowed_values: Optional[Iterable[str]] = None, - encoding: Optional[str] = "ascii", + encoding: str = "ascii", ): """ Parse a string parameter from the request query string. @@ -125,66 +180,65 @@ def parse_string( Args: request: the twisted HTTP request. name: the name of the query parameter. - default: value to use if the parameter is absent, - defaults to None. Must be bytes if encoding is None. + default: value to use if the parameter is absent, defaults to None. required: whether to raise a 400 SynapseError if the parameter is absent, defaults to False. allowed_values: List of allowed values for the string, or None if any value is allowed, defaults to None. Must be the same type as name, if given. - encoding : The encoding to decode the string content with. + encoding: The encoding to decode the string content with. + Returns: - A string value or the default. Unicode if encoding - was given, bytes otherwise. + A string value or the default. Raises: SynapseError if the parameter is absent and required, or if the parameter is present, must be one of a list of allowed values and is not one of those allowed values. """ + args = request.args # type: Dict[bytes, List[bytes]] # type: ignore return parse_string_from_args( - request.args, name, default, required, allowed_values, encoding + args, name, default, required, allowed_values, encoding ) def _parse_string_value( - value: Union[str, bytes], + value: bytes, allowed_values: Optional[Iterable[str]], name: str, - encoding: Optional[str], -) -> Union[str, bytes]: - if encoding: - try: - value = value.decode(encoding) - except ValueError: - raise SynapseError(400, "Query parameter %r must be %s" % (name, encoding)) + encoding: str, +) -> str: + try: + value_str = value.decode(encoding) + except ValueError: + raise SynapseError(400, "Query parameter %r must be %s" % (name, encoding)) - if allowed_values is not None and value not in allowed_values: + if allowed_values is not None and value_str not in allowed_values: message = "Query parameter %r must be one of [%s]" % ( name, ", ".join(repr(v) for v in allowed_values), ) raise SynapseError(400, message) else: - return value + return value_str @overload def parse_strings_from_args( - args: List[str], - name: Union[bytes, str], + args: Dict[bytes, List[bytes]], + name: str, default: Optional[List[str]] = None, - required: bool = False, + required: Literal[True] = True, allowed_values: Optional[Iterable[str]] = None, - encoding: Literal[None] = None, -) -> Optional[List[bytes]]: + encoding: str = "ascii", +) -> List[str]: ... @overload def parse_strings_from_args( - args: List[str], - name: Union[bytes, str], + args: Dict[bytes, List[bytes]], + name: str, default: Optional[List[str]] = None, required: bool = False, allowed_values: Optional[Iterable[str]] = None, @@ -194,46 +248,40 @@ def parse_strings_from_args( def parse_strings_from_args( - args: List[str], - name: Union[bytes, str], + args: Dict[bytes, List[bytes]], + name: str, default: Optional[List[str]] = None, required: bool = False, allowed_values: Optional[Iterable[str]] = None, - encoding: Optional[str] = "ascii", -) -> Optional[List[Union[bytes, str]]]: + encoding: str = "ascii", +) -> Optional[List[str]]: """ Parse a string parameter from the request query string list. - If encoding is not None, the content of the query param will be - decoded to Unicode using the encoding, otherwise it will be encoded + The content of the query param will be decoded to Unicode using the encoding. Args: - args: the twisted HTTP request.args list. + args: A mapping of request args as bytes to a list of bytes (e.g. request.args). name: the name of the query parameter. - default: value to use if the parameter is absent, - defaults to None. Must be bytes if encoding is None. - required : whether to raise a 400 SynapseError if the + default: value to use if the parameter is absent, defaults to None. + required: whether to raise a 400 SynapseError if the parameter is absent, defaults to False. - allowed_values (list[bytes|unicode]): List of allowed values for the - string, or None if any value is allowed, defaults to None. Must be - the same type as name, if given. + allowed_values: List of allowed values for the + string, or None if any value is allowed, defaults to None. encoding: The encoding to decode the string content with. Returns: - A string value or the default. Unicode if encoding - was given, bytes otherwise. + A string value or the default. Raises: SynapseError if the parameter is absent and required, or if the parameter is present, must be one of a list of allowed values and is not one of those allowed values. """ + name_bytes = name.encode("ascii") - if not isinstance(name, bytes): - name = name.encode("ascii") - - if name in args: - values = args[name] + if name_bytes in args: + values = args[name_bytes] return [ _parse_string_value(value, allowed_values, name=name, encoding=encoding) @@ -241,36 +289,30 @@ def parse_strings_from_args( ] else: if required: - message = "Missing string query parameter %r" % (name) + message = "Missing string query parameter %r" % (name,) raise SynapseError(400, message, errcode=Codes.MISSING_PARAM) - else: - if encoding and isinstance(default, bytes): - return default.decode(encoding) - - return default + return default def parse_string_from_args( - args: List[str], - name: Union[bytes, str], + args: Dict[bytes, List[bytes]], + name: str, default: Optional[str] = None, required: bool = False, allowed_values: Optional[Iterable[str]] = None, - encoding: Optional[str] = "ascii", -) -> Optional[Union[bytes, str]]: + encoding: str = "ascii", +) -> Optional[str]: """ Parse the string parameter from the request query string list and return the first result. - If encoding is not None, the content of the query param will be - decoded to Unicode using the encoding, otherwise it will be encoded + The content of the query param will be decoded to Unicode using the encoding. Args: - args: the twisted HTTP request.args list. + args: A mapping of request args as bytes to a list of bytes (e.g. request.args). name: the name of the query parameter. - default: value to use if the parameter is absent, - defaults to None. Must be bytes if encoding is None. + default: value to use if the parameter is absent, defaults to None. required: whether to raise a 400 SynapseError if the parameter is absent, defaults to False. allowed_values: List of allowed values for the @@ -279,8 +321,7 @@ def parse_string_from_args( encoding: The encoding to decode the string content with. Returns: - A string value or the default. Unicode if encoding - was given, bytes otherwise. + A string value or the default. Raises: SynapseError if the parameter is absent and required, or if the @@ -291,12 +332,15 @@ def parse_string_from_args( strings = parse_strings_from_args( args, name, - default=[default], + default=[default] if default is not None else None, required=required, allowed_values=allowed_values, encoding=encoding, ) + if strings is None: + return None + return strings[0] @@ -388,9 +432,8 @@ class RestServlet: def register(self, http_server): """ Register this servlet with the given HTTP server. """ - if hasattr(self, "PATTERNS"): - patterns = self.PATTERNS - + patterns = getattr(self, "PATTERNS", None) + if patterns: for method in ("GET", "PUT", "POST", "DELETE"): if hasattr(self, "on_%s" % (method,)): servlet_classname = self.__class__.__name__ diff --git a/synapse/rest/admin/rooms.py b/synapse/rest/admin/rooms.py index f289ffe3d0..f0cddd2d2c 100644 --- a/synapse/rest/admin/rooms.py +++ b/synapse/rest/admin/rooms.py @@ -649,7 +649,7 @@ class RoomEventContextServlet(RestServlet): limit = parse_integer(request, "limit", default=10) # picking the API shape for symmetry with /messages - filter_str = parse_string(request, b"filter", encoding="utf-8") + filter_str = parse_string(request, "filter", encoding="utf-8") if filter_str: filter_json = urlparse.unquote(filter_str) event_filter = Filter( diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py index 42e709ec14..f6be5f1020 100644 --- a/synapse/rest/client/v1/login.py +++ b/synapse/rest/client/v1/login.py @@ -14,7 +14,7 @@ import logging import re -from typing import TYPE_CHECKING, Awaitable, Callable, Dict, Optional +from typing import TYPE_CHECKING, Awaitable, Callable, Dict, List, Optional from synapse.api.errors import Codes, LoginError, SynapseError from synapse.api.ratelimiting import Ratelimiter @@ -25,6 +25,7 @@ from synapse.http import get_request_uri from synapse.http.server import HttpServer, finish_request from synapse.http.servlet import ( RestServlet, + parse_bytes_from_args, parse_json_object_from_request, parse_string, ) @@ -437,9 +438,8 @@ class SsoRedirectServlet(RestServlet): finish_request(request) return - client_redirect_url = parse_string( - request, "redirectUrl", required=True, encoding=None - ) + args = request.args # type: Dict[bytes, List[bytes]] # type: ignore + client_redirect_url = parse_bytes_from_args(args, "redirectUrl", required=True) sso_url = await self._sso_handler.handle_redirect_request( request, client_redirect_url, diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py index 5a9c27f75f..122105854a 100644 --- a/synapse/rest/client/v1/room.py +++ b/synapse/rest/client/v1/room.py @@ -537,7 +537,7 @@ class RoomMessageListRestServlet(RestServlet): self.store, request, default_limit=10 ) as_client_event = b"raw" not in request.args - filter_str = parse_string(request, b"filter", encoding="utf-8") + filter_str = parse_string(request, "filter", encoding="utf-8") if filter_str: filter_json = urlparse.unquote(filter_str) event_filter = Filter( @@ -652,7 +652,7 @@ class RoomEventContextServlet(RestServlet): limit = parse_integer(request, "limit", default=10) # picking the API shape for symmetry with /messages - filter_str = parse_string(request, b"filter", encoding="utf-8") + filter_str = parse_string(request, "filter", encoding="utf-8") if filter_str: filter_json = urlparse.unquote(filter_str) event_filter = Filter( diff --git a/synapse/rest/consent/consent_resource.py b/synapse/rest/consent/consent_resource.py index b19cd8afc5..e52570cd8e 100644 --- a/synapse/rest/consent/consent_resource.py +++ b/synapse/rest/consent/consent_resource.py @@ -17,6 +17,7 @@ import logging from hashlib import sha256 from http import HTTPStatus from os import path +from typing import Dict, List import jinja2 from jinja2 import TemplateNotFound @@ -24,7 +25,7 @@ from jinja2 import TemplateNotFound from synapse.api.errors import NotFoundError, StoreError, SynapseError from synapse.config import ConfigError from synapse.http.server import DirectServeHtmlResource, respond_with_html -from synapse.http.servlet import parse_string +from synapse.http.servlet import parse_bytes_from_args, parse_string from synapse.types import UserID # language to use for the templates. TODO: figure this out from Accept-Language @@ -116,7 +117,8 @@ class ConsentResource(DirectServeHtmlResource): has_consented = False public_version = username == "" if not public_version: - userhmac_bytes = parse_string(request, "h", required=True, encoding=None) + args = request.args # type: Dict[bytes, List[bytes]] + userhmac_bytes = parse_bytes_from_args(args, "h", required=True) self._check_hash(username, userhmac_bytes) @@ -152,7 +154,8 @@ class ConsentResource(DirectServeHtmlResource): """ version = parse_string(request, "v", required=True) username = parse_string(request, "u", required=True) - userhmac = parse_string(request, "h", required=True, encoding=None) + args = request.args # type: Dict[bytes, List[bytes]] + userhmac = parse_bytes_from_args(args, "h", required=True) self._check_hash(username, userhmac) diff --git a/synapse/rest/media/v1/upload_resource.py b/synapse/rest/media/v1/upload_resource.py index 024a105bf2..62dc4aae2d 100644 --- a/synapse/rest/media/v1/upload_resource.py +++ b/synapse/rest/media/v1/upload_resource.py @@ -14,13 +14,13 @@ # limitations under the License. import logging -from typing import IO, TYPE_CHECKING +from typing import IO, TYPE_CHECKING, Dict, List, Optional from twisted.web.server import Request from synapse.api.errors import Codes, SynapseError from synapse.http.server import DirectServeJsonResource, respond_with_json -from synapse.http.servlet import parse_string +from synapse.http.servlet import parse_bytes_from_args from synapse.http.site import SynapseRequest from synapse.rest.media.v1.media_storage import SpamMediaException @@ -61,10 +61,11 @@ class UploadResource(DirectServeJsonResource): errcode=Codes.TOO_LARGE, ) - upload_name = parse_string(request, b"filename", encoding=None) - if upload_name: + args = request.args # type: Dict[bytes, List[bytes]] # type: ignore + upload_name_bytes = parse_bytes_from_args(args, "filename") + if upload_name_bytes: try: - upload_name = upload_name.decode("utf8") + upload_name = upload_name_bytes.decode("utf8") # type: Optional[str] except UnicodeDecodeError: raise SynapseError( msg="Invalid UTF-8 filename parameter: %r" % (upload_name), code=400