Port over enough to get some sytests running on Python 3 (#3668)

pull/3724/head
Amber Brown 2018-08-20 23:54:49 +10:00 committed by GitHub
parent cf6f9a8b53
commit 324525f40c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
12 changed files with 91 additions and 40 deletions

1
changelog.d/3668.misc Normal file
View File

@ -0,0 +1 @@
Port over enough to Python 3 to allow the sytests to start.

View File

@ -211,7 +211,7 @@ class Auth(object):
user_agent = request.requestHeaders.getRawHeaders(
b"User-Agent",
default=[b""]
)[0]
)[0].decode('ascii', 'surrogateescape')
if user and access_token and ip_addr:
yield self.store.insert_client_ip(
user_id=user.to_string(),
@ -682,7 +682,7 @@ class Auth(object):
Returns:
bool: False if no access_token was given, True otherwise.
"""
query_params = request.args.get("access_token")
query_params = request.args.get(b"access_token")
auth_headers = request.requestHeaders.getRawHeaders(b"Authorization")
return bool(query_params) or bool(auth_headers)
@ -698,7 +698,7 @@ class Auth(object):
401 since some of the old clients depended on auth errors returning
403.
Returns:
str: The access_token
unicode: The access_token
Raises:
AuthError: If there isn't an access_token in the request.
"""
@ -720,9 +720,9 @@ class Auth(object):
"Too many Authorization headers.",
errcode=Codes.MISSING_TOKEN,
)
parts = auth_headers[0].split(" ")
if parts[0] == "Bearer" and len(parts) == 2:
return parts[1]
parts = auth_headers[0].split(b" ")
if parts[0] == b"Bearer" and len(parts) == 2:
return parts[1].decode('ascii')
else:
raise AuthError(
token_not_found_http_status,
@ -738,7 +738,7 @@ class Auth(object):
errcode=Codes.MISSING_TOKEN
)
return query_params[0]
return query_params[0].decode('ascii')
@defer.inlineCallbacks
def check_in_room_or_world_readable(self, room_id, user_id):

View File

@ -72,7 +72,7 @@ class Ratelimiter(object):
return allowed, time_allowed
def prune_message_counts(self, time_now_s):
for user_id in self.message_counts.keys():
for user_id in list(self.message_counts.keys()):
message_count, time_start, msg_rate_hz = (
self.message_counts[user_id]
)

View File

@ -168,7 +168,8 @@ def setup_logging(config, use_worker_options=False):
if log_file:
# TODO: Customisable file size / backup count
handler = logging.handlers.RotatingFileHandler(
log_file, maxBytes=(1000 * 1000 * 100), backupCount=3
log_file, maxBytes=(1000 * 1000 * 100), backupCount=3,
encoding='utf8'
)
def sighup(signum, stack):

View File

@ -29,7 +29,7 @@ def parse_integer(request, name, default=None, required=False):
Args:
request: the twisted HTTP request.
name (str): the name of the query parameter.
name (bytes/unicode): the name of the query parameter.
default (int|None): value to use if the parameter is absent, defaults
to None.
required (bool): whether to raise a 400 SynapseError if the
@ -46,6 +46,10 @@ def parse_integer(request, name, default=None, required=False):
def parse_integer_from_args(args, name, default=None, required=False):
if not isinstance(name, bytes):
name = name.encode('ascii')
if name in args:
try:
return int(args[name][0])
@ -65,7 +69,7 @@ def parse_boolean(request, name, default=None, required=False):
Args:
request: the twisted HTTP request.
name (str): the name of the query parameter.
name (bytes/unicode): the name of the query parameter.
default (bool|None): value to use if the parameter is absent, defaults
to None.
required (bool): whether to raise a 400 SynapseError if the
@ -83,11 +87,15 @@ def parse_boolean(request, name, default=None, required=False):
def parse_boolean_from_args(args, name, default=None, required=False):
if not isinstance(name, bytes):
name = name.encode('ascii')
if name in args:
try:
return {
"true": True,
"false": False,
b"true": True,
b"false": False,
}[args[name][0]]
except Exception:
message = (
@ -104,21 +112,29 @@ def parse_boolean_from_args(args, name, default=None, required=False):
def parse_string(request, name, default=None, required=False,
allowed_values=None, param_type="string"):
"""Parse a string parameter from the request query string.
allowed_values=None, param_type="string", encoding='ascii'):
"""
Parse a string parameter from the request query string.
If encoding is not None, the content of the query param will be
decoded to Unicode using the encoding, otherwise it will be encoded
Args:
request: the twisted HTTP request.
name (str): the name of the query parameter.
default (str|None): value to use if the parameter is absent, defaults
to None.
name (bytes/unicode): the name of the query parameter.
default (bytes/unicode|None): value to use if the parameter is absent,
defaults to None. Must be bytes if encoding is None.
required (bool): whether to raise a 400 SynapseError if the
parameter is absent, defaults to False.
allowed_values (list[str]): List of allowed values for the string,
or None if any value is allowed, defaults to None
allowed_values (list[bytes/unicode]): List of allowed values for the
string, or None if any value is allowed, defaults to None. Must be
the same type as name, if given.
encoding: The encoding to decode the name to, and decode the string
content with.
Returns:
str|None: A string value or the default.
bytes/unicode|None: A string value or the default. Unicode if encoding
was given, bytes otherwise.
Raises:
SynapseError if the parameter is absent and required, or if the
@ -126,14 +142,22 @@ def parse_string(request, name, default=None, required=False,
is not one of those allowed values.
"""
return parse_string_from_args(
request.args, name, default, required, allowed_values, param_type,
request.args, name, default, required, allowed_values, param_type, encoding
)
def parse_string_from_args(args, name, default=None, required=False,
allowed_values=None, param_type="string"):
allowed_values=None, param_type="string", encoding='ascii'):
if not isinstance(name, bytes):
name = name.encode('ascii')
if name in args:
value = args[name][0]
if encoding:
value = value.decode(encoding)
if allowed_values is not None and value not in allowed_values:
message = "Query parameter %r must be one of [%s]" % (
name, ", ".join(repr(v) for v in allowed_values)
@ -146,6 +170,10 @@ def parse_string_from_args(args, name, default=None, required=False,
message = "Missing %s query parameter %r" % (param_type, name)
raise SynapseError(400, message, errcode=Codes.MISSING_PARAM)
else:
if encoding and isinstance(default, bytes):
return default.decode(encoding)
return default

View File

@ -235,7 +235,7 @@ class SynapseRequest(Request):
# need to decode as it could be raw utf-8 bytes
# from a IDN servname in an auth header
authenticated_entity = self.authenticated_entity
if authenticated_entity is not None:
if authenticated_entity is not None and isinstance(authenticated_entity, bytes):
authenticated_entity = authenticated_entity.decode("utf-8", "replace")
# ...or could be raw utf-8 bytes in the User-Agent header.
@ -328,7 +328,7 @@ class SynapseSite(Site):
proxied = config.get("x_forwarded", False)
self.requestFactory = SynapseRequestFactory(self, proxied)
self.access_logger = logging.getLogger(logger_name)
self.server_version_string = server_version_string
self.server_version_string = server_version_string.encode('ascii')
def log(self, request):
pass

View File

@ -53,7 +53,7 @@ class HttpTransactionCache(object):
str: A transaction key
"""
token = self.auth.get_access_token_from_request(request)
return request.path + "/" + token
return request.path.decode('utf8') + "/" + token
def fetch_or_execute_request(self, request, fn, *args, **kwargs):
"""A helper function for fetch_or_execute which extracts

View File

@ -55,7 +55,7 @@ class UploadResource(Resource):
requester = yield self.auth.get_user_by_req(request)
# TODO: The checks here are a bit late. The content will have
# already been uploaded to a tmp file at this point
content_length = request.getHeader("Content-Length")
content_length = request.getHeader(b"Content-Length").decode('ascii')
if content_length is None:
raise SynapseError(
msg="Request must specify a Content-Length", code=400
@ -66,10 +66,10 @@ class UploadResource(Resource):
code=413,
)
upload_name = parse_string(request, "filename")
upload_name = parse_string(request, b"filename", encoding=None)
if upload_name:
try:
upload_name = upload_name.decode('UTF-8')
upload_name = upload_name.decode('utf8')
except UnicodeDecodeError:
raise SynapseError(
msg="Invalid UTF-8 filename parameter: %r" % (upload_name),
@ -78,8 +78,8 @@ class UploadResource(Resource):
headers = request.requestHeaders
if headers.hasHeader("Content-Type"):
media_type = headers.getRawHeaders(b"Content-Type")[0]
if headers.hasHeader(b"Content-Type"):
media_type = headers.getRawHeaders(b"Content-Type")[0].decode('ascii')
else:
raise SynapseError(
msg="Upload request missing 'Content-Type'",

View File

@ -38,4 +38,4 @@ else:
return os.urandom(nbytes)
def token_hex(self, nbytes=32):
return binascii.hexlify(self.token_bytes(nbytes))
return binascii.hexlify(self.token_bytes(nbytes)).decode('ascii')

View File

@ -20,6 +20,8 @@ import time
from functools import wraps
from inspect import getcallargs
from six import PY3
_TIME_FUNC_ID = 0
@ -28,8 +30,12 @@ def _log_debug_as_f(f, msg, msg_args):
logger = logging.getLogger(name)
if logger.isEnabledFor(logging.DEBUG):
lineno = f.func_code.co_firstlineno
pathname = f.func_code.co_filename
if PY3:
lineno = f.__code__.co_firstlineno
pathname = f.__code__.co_filename
else:
lineno = f.func_code.co_firstlineno
pathname = f.func_code.co_filename
record = logging.LogRecord(
name=name,

View File

@ -16,6 +16,7 @@
import random
import string
from six import PY3
from six.moves import range
_string_with_symbols = (
@ -34,6 +35,17 @@ def random_string_with_symbols(length):
def is_ascii(s):
if PY3:
if isinstance(s, bytes):
try:
s.decode('ascii').encode('ascii')
except UnicodeDecodeError:
return False
except UnicodeEncodeError:
return False
return True
try:
s.encode("ascii")
except UnicodeEncodeError:
@ -49,6 +61,9 @@ def to_ascii(s):
If given None then will return None.
"""
if PY3:
return s
if s is None:
return None

View File

@ -30,7 +30,7 @@ def get_version_string(module):
['git', 'rev-parse', '--abbrev-ref', 'HEAD'],
stderr=null,
cwd=cwd,
).strip()
).strip().decode('ascii')
git_branch = "b=" + git_branch
except subprocess.CalledProcessError:
git_branch = ""
@ -40,7 +40,7 @@ def get_version_string(module):
['git', 'describe', '--exact-match'],
stderr=null,
cwd=cwd,
).strip()
).strip().decode('ascii')
git_tag = "t=" + git_tag
except subprocess.CalledProcessError:
git_tag = ""
@ -50,7 +50,7 @@ def get_version_string(module):
['git', 'rev-parse', '--short', 'HEAD'],
stderr=null,
cwd=cwd,
).strip()
).strip().decode('ascii')
except subprocess.CalledProcessError:
git_commit = ""
@ -60,7 +60,7 @@ def get_version_string(module):
['git', 'describe', '--dirty=' + dirty_string],
stderr=null,
cwd=cwd,
).strip().endswith(dirty_string)
).strip().decode('ascii').endswith(dirty_string)
git_dirty = "dirty" if is_dirty else ""
except subprocess.CalledProcessError:
@ -77,8 +77,8 @@ def get_version_string(module):
"%s (%s)" % (
module.__version__, git_version,
)
).encode("ascii")
)
except Exception as e:
logger.info("Failed to check for git repository: %s", e)
return module.__version__.encode("ascii")
return module.__version__