223 lines
		
	
	
		
			6.9 KiB
		
	
	
	
		
			Python
		
	
	
			
		
		
	
	
			223 lines
		
	
	
		
			6.9 KiB
		
	
	
	
		
			Python
		
	
	
| # Copyright 2014-2016 OpenMarket Ltd
 | |
| # Copyright 2020 The Matrix.org Foundation C.I.C.
 | |
| #
 | |
| # Licensed under the Apache License, Version 2.0 (the "License");
 | |
| # you may not use this file except in compliance with the License.
 | |
| # You may obtain a copy of the License at
 | |
| #
 | |
| #     http://www.apache.org/licenses/LICENSE-2.0
 | |
| #
 | |
| # Unless required by applicable law or agreed to in writing, software
 | |
| # distributed under the License is distributed on an "AS IS" BASIS,
 | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | |
| # See the License for the specific language governing permissions and
 | |
| # limitations under the License.
 | |
| import itertools
 | |
| import random
 | |
| import re
 | |
| import string
 | |
| from collections.abc import Iterable
 | |
| from typing import Optional, Tuple
 | |
| 
 | |
| from synapse.api.errors import Codes, SynapseError
 | |
| 
 | |
# Alphabet used by random_string_with_symbols: digits, ASCII letters and a
# selection of punctuation characters.
_string_with_symbols = string.digits + string.ascii_letters + ".,;:^&*-_+=#~@"

# https://matrix.org/docs/spec/client_server/r0.6.0#post-matrix-client-r0-register-email-requesttoken
#
# The pattern is anchored with \Z rather than $: `$` also matches immediately
# before a trailing newline, which would let a value such as "secret\n" pass
# validation.
CLIENT_SECRET_REGEX = re.compile(r"^[0-9a-zA-Z\.=_\-]+\Z")

# https://matrix.org/docs/spec/client_server/r0.6.1#matrix-content-mxc-uris,
# together with https://github.com/matrix-org/matrix-doc/issues/2177 which basically
# says "there is no grammar for media ids"
#
# The server_name part of this is purposely lax: use
# parse_and_validate_mxc_uri for additional validation.
#
MXC_REGEX = re.compile("^mxc://([^/]+)/([^/#?]+)$")

# random_string and random_string_with_symbols are used for a range of things,
# some cryptographically important, some less so. We use SystemRandom to make sure
# we get cryptographically-secure randoms.
rand = random.SystemRandom()
 | |
| 
 | |
| 
 | |
def random_string(length: int) -> str:
    """Build a string of `length` ASCII letters picked with the module-level `rand`.

    Args:
        length: number of characters to generate

    Returns:
        the generated string
    """
    alphabet = string.ascii_letters
    picked = [rand.choice(alphabet) for _ in range(length)]
    return "".join(picked)
 | |
| 
 | |
| 
 | |
def random_string_with_symbols(length: int) -> str:
    """Build a string of `length` characters picked from `_string_with_symbols`.

    Characters are selected with the module-level `rand`.

    Args:
        length: number of characters to generate

    Returns:
        the generated string
    """
    picked = [rand.choice(_string_with_symbols) for _ in range(length)]
    return "".join(picked)
 | |
| 
 | |
| 
 | |
def is_ascii(s: bytes) -> bool:
    """Report whether `s` can be round-tripped through the ASCII codec.

    Args:
        s: the byte string to check

    Returns:
        True if `s` decodes as ASCII and re-encodes without error, False if
        either step raises a Unicode decode/encode error.
    """
    try:
        decoded = s.decode("ascii")
        decoded.encode("ascii")
    except (UnicodeDecodeError, UnicodeEncodeError):
        return False
    else:
        return True
 | |
| 
 | |
| 
 | |
def assert_valid_client_secret(client_secret: str) -> None:
    """Check `client_secret` against the spec's length limits and character set.

    A secret is accepted when it is between 1 and 255 characters long
    (inclusive) and matches CLIENT_SECRET_REGEX.

    Args:
        client_secret: the value to validate

    Raises:
        SynapseError: with status 400 and Codes.INVALID_PARAM if the secret is
            rejected.
    """
    length_ok = 0 < len(client_secret) <= 255
    if length_ok and CLIENT_SECRET_REGEX.match(client_secret) is not None:
        return

    raise SynapseError(
        400, "Invalid client_secret parameter", errcode=Codes.INVALID_PARAM
    )
 | |
| 
 | |
| 
 | |
def parse_server_name(server_name: str) -> Tuple[str, Optional[int]]:
    """Break a server name into its host and (optional) port.

    A name ending in "]" is returned unchanged with no port. Otherwise the
    name is split at its last ":"; whatever follows is converted with int().
    No further validation of the host is done here.

    Args:
        server_name: server name to parse

    Returns:
        a (host, port) tuple; port is None when the name carries no port.

    Raises:
        ValueError if the server name could not be parsed.
    """
    try:
        final_char = server_name[-1]
        if final_char == "]":
            # ipv6 literal without a port, hopefully
            return server_name, None

        host_part, colon, port_part = server_name.rpartition(":")
        if not colon:
            # no ":" at all, so the whole string is the host
            return server_name, None

        return host_part, int(port_part)
    except Exception:
        raise ValueError("Invalid server name '%s'" % server_name)
 | |
| 
 | |
| 
 | |
# Non-empty run of ASCII letters, digits, dots and hyphens, anchored at both
# ends of the string.
VALID_HOST_REGEX = re.compile("\\A[0-9a-zA-Z.-]+\\Z")


def parse_and_validate_server_name(server_name: str) -> Tuple[str, Optional[int]]:
    """Split a server name into host/port parts and do some basic validation.

    The name is first split with parse_server_name. A host starting with "["
    is treated as an IPv6 literal and must also end with "]"; any other host
    must match VALID_HOST_REGEX.

    Args:
        server_name: server name to parse

    Returns:
        host/port parts.

    Raises:
        ValueError if the server name could not be parsed, or if the host part
        is empty, has mismatched brackets, or contains invalid characters.
    """
    host, port = parse_server_name(server_name)

    # these tests don't need to be bulletproof as we'll find out soon enough
    # if somebody is giving us invalid data. What we *do* need is to be sure
    # that nobody is sneaking IP literals in that look like hostnames, etc.

    # look for ipv6 literals.
    #
    # startswith() is used rather than host[0] because the host may be empty
    # (e.g. for ":8448"): indexing would raise IndexError, whereas here the
    # empty host falls through to the regex check and is rejected with the
    # documented ValueError.
    if host.startswith("["):
        if not host.endswith("]"):
            raise ValueError("Mismatched [...] in server name '%s'" % (server_name,))
        return host, port

    # otherwise it should only be alphanumerics.
    if not VALID_HOST_REGEX.match(host):
        raise ValueError(
            "Server name '%s' contains invalid characters" % (server_name,)
        )

    return host, port
 | |
| 
 | |
| 
 | |
def valid_id_server_location(id_server: str) -> bool:
    """Decide whether `id_server` is an acceptable identity server location,
    such as the `id_server` parameter of `/_matrix/client/r0/account/3pid/bind`.

    The location is a server name (hostname plus optional port), optionally
    followed by a `/`-separated path. The server name must be accepted by
    parse_and_validate_server_name, and the path must not contain a fragment
    (`#`) or query string (`?`).

    Args:
        id_server: identity server location string to validate

    Returns:
        True if valid, False otherwise.
    """
    # Everything before the first "/" is the server name; the remainder (which
    # is empty when there is no "/") is the path.
    server_name, _, path = id_server.partition("/")

    try:
        parse_and_validate_server_name(server_name)
    except ValueError:
        return False

    return not ("#" in path or "?" in path)
 | |
| 
 | |
| 
 | |
def parse_and_validate_mxc_uri(mxc: str) -> Tuple[str, Optional[int], str]:
    """Split an MXC URI into its server and media-id components.

    The URI is matched against MXC_REGEX, and the captured server name is then
    checked with parse_and_validate_server_name.

    Args:
        mxc: the (alleged) MXC URI to be checked
    Returns:
        hostname, port, media id
    Raises:
        ValueError if the URI cannot be parsed
    """
    parsed = MXC_REGEX.match(mxc)
    if parsed is None:
        raise ValueError("mxc URI %r did not match expected format" % (mxc,))

    server_name, media_id = parsed.group(1, 2)
    host, port = parse_and_validate_server_name(server_name)
    return host, port, media_id
 | |
| 
 | |
| 
 | |
def shortstr(iterable: Iterable, maxitems: int = 5) -> str:
    """If iterable has maxitems or fewer, return the stringification of a list
    containing those items.

    Otherwise, return the stringification of a list with the first maxitems items,
    followed by "...".

    At most maxitems + 1 items are read from the iterable; if it is a one-shot
    iterator those items are consumed.

    Args:
        iterable: iterable to truncate
        maxitems: number of items to return before truncating

    Returns:
        the (possibly truncated) list rendered as a string, e.g. "[1, 2, ...]".
    """
    # Pull one item more than we intend to show, so we can tell whether there
    # is anything left to elide.
    items = list(itertools.islice(iterable, maxitems + 1))
    if len(items) <= maxitems:
        return str(items)

    # Join the ellipsis together with the shown items so that maxitems=0
    # renders as "[...]" rather than the malformed "[, ...]".
    shown = [repr(item) for item in items[:maxitems]]
    shown.append("...")
    return "[" + ", ".join(shown) + "]"
 | |
| 
 | |
| 
 | |
def strtobool(val: str) -> bool:
    """Interpret a textual truth value as a bool.

    Matching is case-insensitive. 'y', 'yes', 't', 'true', 'on' and '1' map to
    True; 'n', 'no', 'f', 'false', 'off' and '0' map to False. Any other value
    raises ValueError.

    Modelled on distutils.util.strtobool, except that a real bool is returned
    rather than an int.
    """
    truth_table = {
        "y": True,
        "yes": True,
        "t": True,
        "true": True,
        "on": True,
        "1": True,
        "n": False,
        "no": False,
        "f": False,
        "false": False,
        "off": False,
        "0": False,
    }

    normalised = val.lower()
    verdict = truth_table.get(normalised)
    if verdict is None:
        # the error reports the lower-cased form of the input
        raise ValueError("invalid truth value %r" % (normalised,))
    return verdict
 |