258 lines
		
	
	
		
			8.0 KiB
		
	
	
	
		
			Python
		
	
	
			
		
		
	
	
			258 lines
		
	
	
		
			8.0 KiB
		
	
	
	
		
			Python
		
	
	
| # Copyright 2014-2016 OpenMarket Ltd
 | |
| # Copyright 2020 The Matrix.org Foundation C.I.C.
 | |
| #
 | |
| # Licensed under the Apache License, Version 2.0 (the "License");
 | |
| # you may not use this file except in compliance with the License.
 | |
| # You may obtain a copy of the License at
 | |
| #
 | |
| #     http://www.apache.org/licenses/LICENSE-2.0
 | |
| #
 | |
| # Unless required by applicable law or agreed to in writing, software
 | |
| # distributed under the License is distributed on an "AS IS" BASIS,
 | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | |
| # See the License for the specific language governing permissions and
 | |
| # limitations under the License.
 | |
| import itertools
 | |
| import re
 | |
| import secrets
 | |
| import string
 | |
| from typing import Any, Iterable, Optional, Tuple
 | |
| 
 | |
| from netaddr import valid_ipv6
 | |
| 
 | |
| from synapse.api.errors import Codes, SynapseError
 | |
| 
 | |
# Alphabet used by random_string_with_symbols: digits, ASCII letters and a
# handful of punctuation characters.
_string_with_symbols = string.digits + string.ascii_letters + ".,;:^&*-_+=#~@"

# https://matrix.org/docs/spec/client_server/r0.6.0#post-matrix-client-r0-register-email-requesttoken
#
# NB: the pattern is anchored with "\Z" rather than "$". "$" also matches the
# position just before a "\n" at the end of the string, which would allow a
# secret with a trailing newline to pass validation.
CLIENT_SECRET_REGEX = re.compile(r"^[0-9a-zA-Z\.=_\-]+\Z")

# https://matrix.org/docs/spec/client_server/r0.6.1#matrix-content-mxc-uris,
# together with https://github.com/matrix-org/matrix-doc/issues/2177 which basically
# says "there is no grammar for media ids"
#
# The server_name part of this is purposely lax: use parse_and_validate_mxc_uri
# for additional validation.
#
# Group 1 is the server name, group 2 is the media id.
MXC_REGEX = re.compile("^mxc://([^/]+)/([^/#?]+)$")
 | |
| 
 | |
| 
 | |
def random_string(length: int) -> str:
    """Return `length` random ASCII letters chosen with a CSPRNG.

    Every character is picked independently from `a-z` and `A-Z` using the
    `secrets` module.
    """
    alphabet = string.ascii_letters
    picked = [secrets.choice(alphabet) for _ in range(length)]
    return "".join(picked)
 | |
| 
 | |
| 
 | |
def random_string_with_symbols(length: int) -> str:
    """Return `length` random letters/digits/symbols chosen with a CSPRNG.

    Every character is picked independently, using the `secrets` module, from
    `a-z`, `A-Z`, `0-9`, and `.,;:^&*-_+=#~@`.
    """
    picked = [secrets.choice(_string_with_symbols) for _ in range(length)]
    return "".join(picked)
 | |
| 
 | |
| 
 | |
def is_ascii(s: bytes) -> bool:
    """Check whether a byte string consists solely of ASCII bytes.

    Args:
        s: the bytes to check

    Returns:
        True if every byte decodes as ASCII (including the empty byte string),
        False otherwise.
    """
    try:
        # Decoding alone is sufficient: anything that decodes as ASCII will
        # always re-encode as ASCII, so no round-trip is needed.
        s.decode("ascii")
    except UnicodeError:
        return False
    return True
 | |
| 
 | |
| 
 | |
def assert_valid_client_secret(client_secret: str) -> None:
    """Validate that a given string matches the client_secret defined by the spec.

    A valid client_secret is non-empty, at most 255 characters long, and made
    up entirely of characters permitted by CLIENT_SECRET_REGEX.

    Args:
        client_secret: the value to validate

    Raises:
        SynapseError: with HTTP status 400 and errcode Codes.INVALID_PARAM if
            the value is not a valid client_secret.
    """
    # fullmatch() requires the pattern to consume the *entire* string. With
    # match(), a "$" anchor would also accept a value ending in "\n". The
    # pattern's "+" quantifier already rejects the empty string.
    if (
        len(client_secret) > 255
        or CLIENT_SECRET_REGEX.fullmatch(client_secret) is None
    ):
        raise SynapseError(
            400, "Invalid client_secret parameter", errcode=Codes.INVALID_PARAM
        )
 | |
| 
 | |
| 
 | |
def parse_server_name(server_name: str) -> Tuple[str, Optional[int]]:
    """Split a server name into host/port parts.

    No validation of the host part is performed here; a name ending in "]" is
    assumed to be a bare IPv6 literal and returned whole.

    Args:
        server_name: server name to parse

    Returns:
        host/port parts. The port is None if the name does not include one.

    Raises:
        ValueError if the server name could not be parsed (for example, if the
        part after the last ":" is not an integer).
    """
    try:
        if server_name and server_name[-1] == "]":
            # ipv6 literal, hopefully
            return server_name, None

        # Split on the *last* colon so that "[::1]:8448" keeps its literal intact.
        host, colon, port = server_name.rpartition(":")
        if not colon:
            # no port given
            return server_name, None
        return host, int(port)
    except Exception as e:
        # Deliberately broad: any failure (including a non-string argument) is
        # reported as ValueError, with the underlying error kept as the cause.
        raise ValueError("Invalid server name '%s'" % server_name) from e
 | |
| 
 | |
| 
 | |
# An approximation of the domain name syntax in RFC 1035, section 2.3.1.
# NB: "\Z" is used instead of "$" on purpose: "$" would also match the
#     position before a "\n" at the end of a string.
VALID_HOST_REGEX = re.compile("\\A[0-9a-zA-Z-]+(?:\\.[0-9a-zA-Z-]+)*\\Z")


def parse_and_validate_server_name(server_name: str) -> Tuple[str, Optional[int]]:
    """Split a server name into host/port parts and sanity-check the host.

    A host beginning with "[" must be a bracketed, valid IPv6 literal; any
    other host must match VALID_HOST_REGEX.

    Args:
        server_name: server name to parse

    Returns:
        host/port parts.

    Raises:
        ValueError if the server name could not be parsed.
    """
    host, port = parse_server_name(server_name)

    # These checks are not meant to be exhaustive: invalid data will be found
    # out soon enough. What matters is that IP literals cannot be smuggled in
    # disguised as hostnames, and the like.

    is_bracketed = bool(host) and host[0] == "["
    if not is_bracketed:
        # not an ipv6 literal: must look like a hostname
        if VALID_HOST_REGEX.match(host) is None:
            raise ValueError("Server name '%s' has an invalid format" % (server_name,))
        return host, port

    if host[-1] != "]":
        raise ValueError("Mismatched [...] in server name '%s'" % (server_name,))

    # valid_ipv6 raises when given an empty string, so check for that first
    address = host[1:-1]
    if address and valid_ipv6(address):
        return host, port

    raise ValueError("Server name '%s' is not a valid IPv6 address" % (server_name,))
 | |
| 
 | |
| 
 | |
def valid_id_server_location(id_server: str) -> bool:
    """Decide whether an identity server location is acceptable.

    This is the kind of value passed as the `id_server` parameter to
    `/_matrix/client/r0/account/3pid/bind`. It is valid when it consists of a
    valid hostname with an optional port number, optionally followed by any
    number of `/` delimited path components, and contains no fragment or
    query string parts.

    Args:
        id_server: identity server location string to validate

    Returns:
        True if valid, False otherwise.
    """
    # Everything up to the first "/" is the server name; the rest is the path.
    host, slash, path = id_server.partition("/")

    try:
        parse_and_validate_server_name(host)
    except ValueError:
        return False

    if not slash:
        # no path
        return True

    return not ("#" in path or "?" in path)
 | |
| 
 | |
| 
 | |
def parse_and_validate_mxc_uri(mxc: str) -> Tuple[str, Optional[int], str]:
    """Break an MXC URI into its components, validating the server name.

    The URI must match MXC_REGEX, and its "server name" part must be accepted
    by parse_and_validate_server_name.

    Args:
        mxc: the (alleged) MXC URI to be checked
    Returns:
        hostname, port, media id
    Raises:
        ValueError if the URI cannot be parsed
    """
    match = MXC_REGEX.match(mxc)
    if match is None:
        raise ValueError("mxc URI %r did not match expected format" % (mxc,))

    server_name, media_id = match.group(1), match.group(2)
    host, port = parse_and_validate_server_name(server_name)
    return host, port, media_id
 | |
| 
 | |
| 
 | |
def shortstr(iterable: Iterable, maxitems: int = 5) -> str:
    """If iterable has maxitems or fewer, return the stringification of a list
    containing those items.

    Otherwise, return the stringification of a list with the first maxitems items,
    followed by "...".

    At most maxitems + 1 items are consumed from the iterable.

    Args:
        iterable: iterable to truncate
        maxitems: number of items to return before truncating. Negative values
            are treated as 0.
    """
    # A negative limit makes no sense (and would make islice raise), so treat
    # it as "show no items".
    maxitems = max(maxitems, 0)

    # Pull one item more than we will show, to find out if there is anything
    # left to elide.
    items = list(itertools.islice(iterable, maxitems + 1))
    if len(items) <= maxitems:
        return str(items)

    # Join the ellipsis as if it were another element, so that a limit of 0
    # gives "[...]" rather than "[, ...]".
    shown = [repr(item) for item in items[:maxitems]]
    shown.append("...")
    return "[" + ", ".join(shown) + "]"
 | |
| 
 | |
| 
 | |
def strtobool(val: str) -> bool:
    """Interpret a string as a boolean, case-insensitively.

    'y', 'yes', 't', 'true', 'on', and '1' map to True; 'n', 'no', 'f',
    'false', 'off', and '0' map to False. Any other value raises ValueError.

    This mirrors distutils.util.strtobool, except that the result is a real
    bool rather than an int.
    """
    truth_table = {
        "y": True,
        "yes": True,
        "t": True,
        "true": True,
        "on": True,
        "1": True,
        "n": False,
        "no": False,
        "f": False,
        "false": False,
        "off": False,
        "0": False,
    }

    val = val.lower()
    result = truth_table.get(val)
    if result is None:
        raise ValueError("invalid truth value %r" % (val,))
    return result
 | |
| 
 | |
| 
 | |
# Digit alphabet for base62_encode: the character at index i represents value i.
_BASE62 = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"


def base62_encode(num: int, minwidth: int = 1) -> str:
    """Encode a non-negative number using base62

    Args:
        num: number to be encoded; must not be negative
        minwidth: width to pad to (with leading "0"s), if the number is small

    Returns:
        The base62 representation of num, most significant digit first.

    Raises:
        ValueError: if num is negative.
    """
    if num < 0:
        # Python's divmod floors towards negative infinity, so a negative
        # number would never reach zero and the loop below would never end.
        raise ValueError("cannot base62-encode negative number %r" % (num,))

    # Digits come out least-significant first; reverse them at the end.
    digits = []
    while num:
        num, rem = divmod(num, 62)
        digits.append(_BASE62[rem])
    res = "".join(reversed(digits))

    # pad to minimum width
    return res.rjust(minwidth, "0")
 | |
| 
 | |
| 
 | |
def non_null_str_or_none(val: Any) -> Optional[str]:
    """Return `val` unchanged if it is a string free of null (U+0000) codepoints.

    Anything else (a non-string, or a string containing a null) yields None.
    """
    if not isinstance(val, str):
        return None
    if "\u0000" in val:
        return None
    return val
 |