Use parse_{int,str} and assert from http.servlet

parse_integer and parse_string can take a request and raise errors
in case we have wrong or missing params.
This PR tries to use them more to deduplicate some code and make it
better readable
pull/3534/head
Krombel 2018-07-13 21:40:14 +02:00
parent 2aba1f549c
commit 32fd6910d0
14 changed files with 90 additions and 154 deletions

View File

@ -22,7 +22,12 @@ from twisted.internet import defer
from synapse.api.constants import Membership from synapse.api.constants import Membership
from synapse.api.errors import AuthError, Codes, NotFoundError, SynapseError from synapse.api.errors import AuthError, Codes, NotFoundError, SynapseError
from synapse.http.servlet import parse_json_object_from_request from synapse.http.servlet import (
assert_params_in_request,
parse_json_object_from_request,
parse_integer,
parse_string
)
from synapse.types import UserID, create_requester from synapse.types import UserID, create_requester
from .base import ClientV1RestServlet, client_path_patterns from .base import ClientV1RestServlet, client_path_patterns
@ -98,16 +103,8 @@ class PurgeMediaCacheRestServlet(ClientV1RestServlet):
if not is_admin: if not is_admin:
raise AuthError(403, "You are not a server admin") raise AuthError(403, "You are not a server admin")
before_ts = request.args.get("before_ts", None) before_ts = parse_integer(request, "before_ts", required=True)
if not before_ts: logger.info("before_ts: %r", before_ts)
raise SynapseError(400, "Missing 'before_ts' arg")
logger.info("before_ts: %r", before_ts[0])
try:
before_ts = int(before_ts[0])
except Exception:
raise SynapseError(400, "Invalid 'before_ts' arg")
ret = yield self.media_repository.delete_old_remote_media(before_ts) ret = yield self.media_repository.delete_old_remote_media(before_ts)
@ -300,10 +297,8 @@ class ShutdownRoomRestServlet(ClientV1RestServlet):
raise AuthError(403, "You are not a server admin") raise AuthError(403, "You are not a server admin")
content = parse_json_object_from_request(request) content = parse_json_object_from_request(request)
assert_params_in_request(content, ["new_room_user_id"])
new_room_user_id = content.get("new_room_user_id") new_room_user_id = content["new_room_user_id"]
if not new_room_user_id:
raise SynapseError(400, "Please provide field `new_room_user_id`")
room_creator_requester = create_requester(new_room_user_id) room_creator_requester = create_requester(new_room_user_id)
@ -464,9 +459,8 @@ class ResetPasswordRestServlet(ClientV1RestServlet):
raise AuthError(403, "You are not a server admin") raise AuthError(403, "You are not a server admin")
params = parse_json_object_from_request(request) params = parse_json_object_from_request(request)
assert_params_in_request(params, ["new_password"])
new_password = params['new_password'] new_password = params['new_password']
if not new_password:
raise SynapseError(400, "Missing 'new_password' arg")
logger.info("new_password: %r", new_password) logger.info("new_password: %r", new_password)
@ -514,12 +508,9 @@ class GetUsersPaginatedRestServlet(ClientV1RestServlet):
raise SynapseError(400, "Can only users a local user") raise SynapseError(400, "Can only users a local user")
order = "name" # order by name in user table order = "name" # order by name in user table
start = request.args.get("start")[0] start = parse_integer(request, "start", required=True)
limit = request.args.get("limit")[0] limit = parse_integer(request, "limit", required=True)
if not limit:
raise SynapseError(400, "Missing 'limit' arg")
if not start:
raise SynapseError(400, "Missing 'start' arg")
logger.info("limit: %s, start: %s", limit, start) logger.info("limit: %s, start: %s", limit, start)
ret = yield self.handlers.admin_handler.get_users_paginate( ret = yield self.handlers.admin_handler.get_users_paginate(
@ -551,12 +542,9 @@ class GetUsersPaginatedRestServlet(ClientV1RestServlet):
order = "name" # order by name in user table order = "name" # order by name in user table
params = parse_json_object_from_request(request) params = parse_json_object_from_request(request)
assert_params_in_request(params, ["limit", "start"])
limit = params['limit'] limit = params['limit']
start = params['start'] start = params['start']
if not limit:
raise SynapseError(400, "Missing 'limit' arg")
if not start:
raise SynapseError(400, "Missing 'start' arg")
logger.info("limit: %s, start: %s", limit, start) logger.info("limit: %s, start: %s", limit, start)
ret = yield self.handlers.admin_handler.get_users_paginate( ret = yield self.handlers.admin_handler.get_users_paginate(
@ -604,10 +592,7 @@ class SearchUsersRestServlet(ClientV1RestServlet):
if not self.hs.is_mine(target_user): if not self.hs.is_mine(target_user):
raise SynapseError(400, "Can only users a local user") raise SynapseError(400, "Can only users a local user")
term = request.args.get("term")[0] term = parse_string(request, "term", required=True)
if not term:
raise SynapseError(400, "Missing 'term' arg")
logger.info("term: %s ", term) logger.info("term: %s ", term)
ret = yield self.handlers.admin_handler.search_users( ret = yield self.handlers.admin_handler.search_users(

View File

@ -18,8 +18,8 @@ import logging
from twisted.internet import defer from twisted.internet import defer
from synapse.api.errors import AuthError, Codes, SynapseError from synapse.api.errors import AuthError, SynapseError
from synapse.http.servlet import parse_json_object_from_request from synapse.http.servlet import assert_params_in_request, parse_json_object_from_request
from synapse.types import RoomAlias from synapse.types import RoomAlias
from .base import ClientV1RestServlet, client_path_patterns from .base import ClientV1RestServlet, client_path_patterns
@ -52,15 +52,14 @@ class ClientDirectoryServer(ClientV1RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_PUT(self, request, room_alias): def on_PUT(self, request, room_alias):
room_alias = RoomAlias.from_string(room_alias)
content = parse_json_object_from_request(request) content = parse_json_object_from_request(request)
if "room_id" not in content: if "room_id" not in content:
raise SynapseError(400, "Missing room_id key", raise SynapseError(400, 'Missing params: ["room_id"]',
errcode=Codes.BAD_JSON) errcode=Codes.BAD_JSON)
logger.debug("Got content: %s", content) logger.debug("Got content: %s", content)
room_alias = RoomAlias.from_string(room_alias)
logger.debug("Got room name: %s", room_alias.to_string()) logger.debug("Got room name: %s", room_alias.to_string())
room_id = content["room_id"] room_id = content["room_id"]

View File

@ -16,6 +16,7 @@
from twisted.internet import defer from twisted.internet import defer
from synapse.streams.config import PaginationConfig from synapse.streams.config import PaginationConfig
from synapse.http.servlet import parse_boolean
from .base import ClientV1RestServlet, client_path_patterns from .base import ClientV1RestServlet, client_path_patterns
@ -33,7 +34,7 @@ class InitialSyncRestServlet(ClientV1RestServlet):
requester = yield self.auth.get_user_by_req(request) requester = yield self.auth.get_user_by_req(request)
as_client_event = "raw" not in request.args as_client_event = "raw" not in request.args
pagination_config = PaginationConfig.from_request(request) pagination_config = PaginationConfig.from_request(request)
include_archived = request.args.get("archived", None) == ["true"] include_archived = parse_boolean(request, "archived", default=False)
content = yield self.initial_sync_handler.snapshot_all_rooms( content = yield self.initial_sync_handler.snapshot_all_rooms(
user_id=requester.user.to_string(), user_id=requester.user.to_string(),
pagin_config=pagination_config, pagin_config=pagination_config,

View File

@ -21,7 +21,7 @@ from synapse.api.errors import (
SynapseError, SynapseError,
UnrecognizedRequestError, UnrecognizedRequestError,
) )
from synapse.http.servlet import parse_json_value_from_request from synapse.http.servlet import parse_json_value_from_request, parse_string
from synapse.push.baserules import BASE_RULE_IDS from synapse.push.baserules import BASE_RULE_IDS
from synapse.push.clientformat import format_push_rules_for_user from synapse.push.clientformat import format_push_rules_for_user
from synapse.push.rulekinds import PRIORITY_CLASS_MAP from synapse.push.rulekinds import PRIORITY_CLASS_MAP
@ -75,11 +75,11 @@ class PushRuleRestServlet(ClientV1RestServlet):
except InvalidRuleException as e: except InvalidRuleException as e:
raise SynapseError(400, e.message) raise SynapseError(400, e.message)
before = request.args.get("before", None) before = parse_string(request, "before")
if before: if before:
before = _namespaced_rule_id(spec, before[0]) before = _namespaced_rule_id(spec, before[0])
after = request.args.get("after", None) after = parse_string(request, "after")
if after: if after:
after = _namespaced_rule_id(spec, after[0]) after = _namespaced_rule_id(spec, after[0])

View File

@ -21,6 +21,7 @@ from synapse.api.errors import Codes, StoreError, SynapseError
from synapse.http.server import finish_request from synapse.http.server import finish_request
from synapse.http.servlet import ( from synapse.http.servlet import (
RestServlet, RestServlet,
assert_params_in_request,
parse_json_object_from_request, parse_json_object_from_request,
parse_string, parse_string,
) )
@ -91,15 +92,11 @@ class PushersSetRestServlet(ClientV1RestServlet):
) )
defer.returnValue((200, {})) defer.returnValue((200, {}))
reqd = ['kind', 'app_id', 'app_display_name', assert_params_in_request(
'device_display_name', 'pushkey', 'lang', 'data'] content,
missing = [] ['kind', 'app_id', 'app_display_name',
for i in reqd: 'device_display_name', 'pushkey', 'lang', 'data']
if i not in content: )
missing.append(i)
if len(missing):
raise SynapseError(400, "Missing parameters: " + ','.join(missing),
errcode=Codes.MISSING_PARAM)
logger.debug("set pushkey %s to kind %s", content['pushkey'], content['kind']) logger.debug("set pushkey %s to kind %s", content['pushkey'], content['kind'])
logger.debug("Got pushers request with body: %r", content) logger.debug("Got pushers request with body: %r", content)
@ -148,7 +145,7 @@ class PushersRemoveRestServlet(RestServlet):
SUCCESS_HTML = "<html><body>You have been unsubscribed</body><html>" SUCCESS_HTML = "<html><body>You have been unsubscribed</body><html>"
def __init__(self, hs): def __init__(self, hs):
super(RestServlet, self).__init__() super(PushersRemoveRestServlet, self).__init__()
self.hs = hs self.hs = hs
self.notifier = hs.get_notifier() self.notifier = hs.get_notifier()
self.auth = hs.get_auth() self.auth = hs.get_auth()

View File

@ -18,15 +18,13 @@ import hmac
import logging import logging
from hashlib import sha1 from hashlib import sha1
from six import string_types
from twisted.internet import defer from twisted.internet import defer
import synapse.util.stringutils as stringutils import synapse.util.stringutils as stringutils
from synapse.api.auth import get_access_token_from_request from synapse.api.auth import get_access_token_from_request
from synapse.api.constants import LoginType from synapse.api.constants import LoginType
from synapse.api.errors import Codes, SynapseError from synapse.api.errors import Codes, SynapseError
from synapse.http.servlet import parse_json_object_from_request from synapse.http.servlet import assert_params_in_request, parse_json_object_from_request
from synapse.types import create_requester from synapse.types import create_requester
from .base import ClientV1RestServlet, client_path_patterns from .base import ClientV1RestServlet, client_path_patterns
@ -124,8 +122,7 @@ class RegisterRestServlet(ClientV1RestServlet):
session = (register_json["session"] session = (register_json["session"]
if "session" in register_json else None) if "session" in register_json else None)
login_type = None login_type = None
if "type" not in register_json: assert_params_in_request(register_json, ["type"])
raise SynapseError(400, "Missing 'type' key.")
try: try:
login_type = register_json["type"] login_type = register_json["type"]
@ -312,9 +309,7 @@ class RegisterRestServlet(ClientV1RestServlet):
def _do_app_service(self, request, register_json, session): def _do_app_service(self, request, register_json, session):
as_token = get_access_token_from_request(request) as_token = get_access_token_from_request(request)
if "user" not in register_json: assert_params_in_request(register_json, ["user"])
raise SynapseError(400, "Expected 'user' key.")
user_localpart = register_json["user"].encode("utf-8") user_localpart = register_json["user"].encode("utf-8")
handler = self.handlers.registration_handler handler = self.handlers.registration_handler
@ -331,12 +326,7 @@ class RegisterRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def _do_shared_secret(self, request, register_json, session): def _do_shared_secret(self, request, register_json, session):
if not isinstance(register_json.get("mac", None), string_types): assert_params_in_request(register_json, ["mac", "user", "password"])
raise SynapseError(400, "Expected mac.")
if not isinstance(register_json.get("user", None), string_types):
raise SynapseError(400, "Expected 'user' key.")
if not isinstance(register_json.get("password", None), string_types):
raise SynapseError(400, "Expected 'password' key.")
if not self.hs.config.registration_shared_secret: if not self.hs.config.registration_shared_secret:
raise SynapseError(400, "Shared secret registration is not enabled") raise SynapseError(400, "Shared secret registration is not enabled")
@ -419,11 +409,7 @@ class CreateUserRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def _do_create(self, requester, user_json): def _do_create(self, requester, user_json):
if "localpart" not in user_json: assert_params_in_request(user_json, ["localpart", "displayname"])
raise SynapseError(400, "Expected 'localpart' key.")
if "displayname" not in user_json:
raise SynapseError(400, "Expected 'displayname' key.")
localpart = user_json["localpart"].encode("utf-8") localpart = user_json["localpart"].encode("utf-8")
displayname = user_json["displayname"].encode("utf-8") displayname = user_json["displayname"].encode("utf-8")

View File

@ -28,6 +28,7 @@ from synapse.api.errors import AuthError, Codes, SynapseError
from synapse.api.filtering import Filter from synapse.api.filtering import Filter
from synapse.events.utils import format_event_for_client_v2, serialize_event from synapse.events.utils import format_event_for_client_v2, serialize_event
from synapse.http.servlet import ( from synapse.http.servlet import (
assert_params_in_request,
parse_integer, parse_integer,
parse_json_object_from_request, parse_json_object_from_request,
parse_string, parse_string,
@ -435,7 +436,7 @@ class RoomMessageListRestServlet(ClientV1RestServlet):
request, default_limit=10, request, default_limit=10,
) )
as_client_event = "raw" not in request.args as_client_event = "raw" not in request.args
filter_bytes = request.args.get("filter", None) filter_bytes = parse_string(request, "filter")
if filter_bytes: if filter_bytes:
filter_json = urlparse.unquote(filter_bytes[-1]).decode("UTF-8") filter_json = urlparse.unquote(filter_bytes[-1]).decode("UTF-8")
event_filter = Filter(json.loads(filter_json)) event_filter = Filter(json.loads(filter_json))
@ -530,7 +531,7 @@ class RoomEventContextServlet(ClientV1RestServlet):
def on_GET(self, request, room_id, event_id): def on_GET(self, request, room_id, event_id):
requester = yield self.auth.get_user_by_req(request, allow_guest=True) requester = yield self.auth.get_user_by_req(request, allow_guest=True)
limit = int(request.args.get("limit", [10])[0]) limit = parse_integer(request, "limit", default=10)
results = yield self.handlers.room_context_handler.get_event_context( results = yield self.handlers.room_context_handler.get_event_context(
requester.user, requester.user,
@ -636,8 +637,7 @@ class RoomMembershipRestServlet(ClientV1RestServlet):
target = requester.user target = requester.user
if membership_action in ["invite", "ban", "unban", "kick"]: if membership_action in ["invite", "ban", "unban", "kick"]:
if "user_id" not in content: assert_params_in_request(content, ["user_id"])
raise SynapseError(400, "Missing user_id key.")
target = UserID.from_string(content["user_id"]) target = UserID.from_string(content["user_id"])
event_content = None event_content = None
@ -764,7 +764,7 @@ class SearchRestServlet(ClientV1RestServlet):
content = parse_json_object_from_request(request) content = parse_json_object_from_request(request)
batch = request.args.get("next_batch", [None])[0] batch = parse_string(request, "next_batch")
results = yield self.handlers.search_handler.search( results = yield self.handlers.search_handler.search(
requester.user, requester.user,
content, content,

View File

@ -160,11 +160,10 @@ class PasswordRestServlet(RestServlet):
raise SynapseError(404, "Email address not found", Codes.NOT_FOUND) raise SynapseError(404, "Email address not found", Codes.NOT_FOUND)
user_id = threepid_user_id user_id = threepid_user_id
else: else:
logger.error("Auth succeeded but no known type!", result.keys()) logger.error("Auth succeeded but no known type! %r", result.keys())
raise SynapseError(500, "", Codes.UNKNOWN) raise SynapseError(500, "", Codes.UNKNOWN)
if 'new_password' not in params: assert_params_in_request(params, ["new_password"])
raise SynapseError(400, "", Codes.MISSING_PARAM)
new_password = params['new_password'] new_password = params['new_password']
yield self._set_password_handler.set_password( yield self._set_password_handler.set_password(
@ -229,15 +228,10 @@ class EmailThreepidRequestTokenRestServlet(RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_POST(self, request): def on_POST(self, request):
body = parse_json_object_from_request(request) body = parse_json_object_from_request(request)
assert_params_in_request(
required = ['id_server', 'client_secret', 'email', 'send_attempt'] body,
absent = [] ['id_server', 'client_secret', 'email', 'send_attempt'],
for k in required: )
if k not in body:
absent.append(k)
if absent:
raise SynapseError(400, "Missing params: %r" % absent, Codes.MISSING_PARAM)
if not check_3pid_allowed(self.hs, "email", body['email']): if not check_3pid_allowed(self.hs, "email", body['email']):
raise SynapseError( raise SynapseError(
@ -267,18 +261,10 @@ class MsisdnThreepidRequestTokenRestServlet(RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_POST(self, request): def on_POST(self, request):
body = parse_json_object_from_request(request) body = parse_json_object_from_request(request)
assert_params_in_request(body, [
required = [
'id_server', 'client_secret', 'id_server', 'client_secret',
'country', 'phone_number', 'send_attempt', 'country', 'phone_number', 'send_attempt',
] ])
absent = []
for k in required:
if k not in body:
absent.append(k)
if absent:
raise SynapseError(400, "Missing params: %r" % absent, Codes.MISSING_PARAM)
msisdn = phone_number_to_msisdn(body['country'], body['phone_number']) msisdn = phone_number_to_msisdn(body['country'], body['phone_number'])
@ -373,15 +359,7 @@ class ThreepidDeleteRestServlet(RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_POST(self, request): def on_POST(self, request):
body = parse_json_object_from_request(request) body = parse_json_object_from_request(request)
assert_params_in_request(body, ['medium', 'address'])
required = ['medium', 'address']
absent = []
for k in required:
if k not in body:
absent.append(k)
if absent:
raise SynapseError(400, "Missing params: %r" % absent, Codes.MISSING_PARAM)
requester = yield self.auth.get_user_by_req(request) requester = yield self.auth.get_user_by_req(request)
user_id = requester.user.to_string() user_id = requester.user.to_string()

View File

@ -18,14 +18,18 @@ import logging
from twisted.internet import defer from twisted.internet import defer
from synapse.api import errors from synapse.api import errors
from synapse.http import servlet from synapse.http.servlet import (
assert_params_in_request,
parse_json_object_from_request,
RestServlet
)
from ._base import client_v2_patterns, interactive_auth_handler from ._base import client_v2_patterns, interactive_auth_handler
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class DevicesRestServlet(servlet.RestServlet): class DevicesRestServlet(RestServlet):
PATTERNS = client_v2_patterns("/devices$", v2_alpha=False) PATTERNS = client_v2_patterns("/devices$", v2_alpha=False)
def __init__(self, hs): def __init__(self, hs):
@ -47,7 +51,7 @@ class DevicesRestServlet(servlet.RestServlet):
defer.returnValue((200, {"devices": devices})) defer.returnValue((200, {"devices": devices}))
class DeleteDevicesRestServlet(servlet.RestServlet): class DeleteDevicesRestServlet(RestServlet):
""" """
API for bulk deletion of devices. Accepts a JSON object with a devices API for bulk deletion of devices. Accepts a JSON object with a devices
key which lists the device_ids to delete. Requires user interactive auth. key which lists the device_ids to delete. Requires user interactive auth.
@ -67,19 +71,17 @@ class DeleteDevicesRestServlet(servlet.RestServlet):
requester = yield self.auth.get_user_by_req(request) requester = yield self.auth.get_user_by_req(request)
try: try:
body = servlet.parse_json_object_from_request(request) body = parse_json_object_from_request(request)
except errors.SynapseError as e: except errors.SynapseError as e:
if e.errcode == errors.Codes.NOT_JSON: if e.errcode == errors.Codes.NOT_JSON:
# deal with older clients which didn't pass a J*DELETESON dict # DELETE
# deal with older clients which didn't pass a JSON dict
# the same as those that pass an empty dict # the same as those that pass an empty dict
body = {} body = {}
else: else:
raise e raise e
if 'devices' not in body: assert_params_in_request(body, ["devices"])
raise errors.SynapseError(
400, "No devices supplied", errcode=errors.Codes.MISSING_PARAM
)
yield self.auth_handler.validate_user_via_ui_auth( yield self.auth_handler.validate_user_via_ui_auth(
requester, body, self.hs.get_ip_from_request(request), requester, body, self.hs.get_ip_from_request(request),
@ -92,7 +94,7 @@ class DeleteDevicesRestServlet(servlet.RestServlet):
defer.returnValue((200, {})) defer.returnValue((200, {}))
class DeviceRestServlet(servlet.RestServlet): class DeviceRestServlet(RestServlet):
PATTERNS = client_v2_patterns("/devices/(?P<device_id>[^/]*)$", v2_alpha=False) PATTERNS = client_v2_patterns("/devices/(?P<device_id>[^/]*)$", v2_alpha=False)
def __init__(self, hs): def __init__(self, hs):
@ -121,7 +123,7 @@ class DeviceRestServlet(servlet.RestServlet):
requester = yield self.auth.get_user_by_req(request) requester = yield self.auth.get_user_by_req(request)
try: try:
body = servlet.parse_json_object_from_request(request) body = parse_json_object_from_request(request)
except errors.SynapseError as e: except errors.SynapseError as e:
if e.errcode == errors.Codes.NOT_JSON: if e.errcode == errors.Codes.NOT_JSON:
@ -144,7 +146,7 @@ class DeviceRestServlet(servlet.RestServlet):
def on_PUT(self, request, device_id): def on_PUT(self, request, device_id):
requester = yield self.auth.get_user_by_req(request, allow_guest=True) requester = yield self.auth.get_user_by_req(request, allow_guest=True)
body = servlet.parse_json_object_from_request(request) body = parse_json_object_from_request(request)
yield self.device_handler.update_device( yield self.device_handler.update_device(
requester.user.to_string(), requester.user.to_string(),
device_id, device_id,

View File

@ -387,9 +387,7 @@ class RegisterRestServlet(RestServlet):
add_msisdn = False add_msisdn = False
else: else:
# NB: This may be from the auth handler and NOT from the POST # NB: This may be from the auth handler and NOT from the POST
if 'password' not in params: assert_params_in_request(params, ["password"])
raise SynapseError(400, "Missing password.",
Codes.MISSING_PARAM)
desired_username = params.get("username", None) desired_username = params.get("username", None)
new_password = params.get("password", None) new_password = params.get("password", None)
@ -566,11 +564,14 @@ class RegisterRestServlet(RestServlet):
Returns: Returns:
defer.Deferred: defer.Deferred:
""" """
reqd = ('medium', 'address', 'validated_at') try:
if any(x not in threepid for x in reqd): assert_params_in_request(threepid, ['medium', 'address', 'validated_at'])
# This will only happen if the ID server returns a malformed response except SynapseError as ex:
logger.info("Can't add incomplete 3pid") if ex.errcode == Codes.MISSING_PARAM:
defer.returnValue() # This will only happen if the ID server returns a malformed response
logger.info("Can't add incomplete 3pid")
defer.returnValue(None)
raise
yield self.auth_handler.add_threepid( yield self.auth_handler.add_threepid(
user_id, user_id,

View File

@ -14,6 +14,8 @@
from pydenticon import Generator from pydenticon import Generator
from synapse.http.servlet import parse_integer
from twisted.web.resource import Resource from twisted.web.resource import Resource
FOREGROUND = [ FOREGROUND = [
@ -56,8 +58,8 @@ class IdenticonResource(Resource):
def render_GET(self, request): def render_GET(self, request):
name = "/".join(request.postpath) name = "/".join(request.postpath)
width = int(request.args.get("width", [96])[0]) width = parse_integer(request, "width", default=96)
height = int(request.args.get("height", [96])[0]) height = parse_integer(request, "height", default=96)
identicon_bytes = self.generate_identicon(name, width, height) identicon_bytes = self.generate_identicon(name, width, height)
request.setHeader(b"Content-Type", b"image/png") request.setHeader(b"Content-Type", b"image/png")
request.setHeader( request.setHeader(

View File

@ -40,6 +40,7 @@ from synapse.http.server import (
respond_with_json_bytes, respond_with_json_bytes,
wrap_json_request_handler, wrap_json_request_handler,
) )
from synapse.http.servlet import parse_integer, parse_string
from synapse.util.async import ObservableDeferred from synapse.util.async import ObservableDeferred
from synapse.util.caches.expiringcache import ExpiringCache from synapse.util.caches.expiringcache import ExpiringCache
from synapse.util.logcontext import make_deferred_yieldable, run_in_background from synapse.util.logcontext import make_deferred_yieldable, run_in_background
@ -96,9 +97,9 @@ class PreviewUrlResource(Resource):
# XXX: if get_user_by_req fails, what should we do in an async render? # XXX: if get_user_by_req fails, what should we do in an async render?
requester = yield self.auth.get_user_by_req(request) requester = yield self.auth.get_user_by_req(request)
url = request.args.get("url")[0] url = parse_string(request, "url")
if "ts" in request.args: if "ts" in request.args:
ts = int(request.args.get("ts")[0]) ts = parse_integer(request, "ts")
else: else:
ts = self.clock.time_msec() ts = self.clock.time_msec()

View File

@ -21,6 +21,7 @@ from twisted.web.server import NOT_DONE_YET
from synapse.api.errors import SynapseError from synapse.api.errors import SynapseError
from synapse.http.server import respond_with_json, wrap_json_request_handler from synapse.http.server import respond_with_json, wrap_json_request_handler
from synapse.http.servlet import parse_string
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -65,10 +66,10 @@ class UploadResource(Resource):
code=413, code=413,
) )
upload_name = request.args.get("filename", None) upload_name = parse_string(request, "filename")
if upload_name: if upload_name:
try: try:
upload_name = upload_name[0].decode('UTF-8') upload_name = upload_name.decode('UTF-8')
except UnicodeDecodeError: except UnicodeDecodeError:
raise SynapseError( raise SynapseError(
msg="Invalid UTF-8 filename parameter: %r" % (upload_name), msg="Invalid UTF-8 filename parameter: %r" % (upload_name),

View File

@ -16,6 +16,7 @@
import logging import logging
from synapse.api.errors import SynapseError from synapse.api.errors import SynapseError
from synapse.http.servlet import parse_integer, parse_string
from synapse.types import StreamToken from synapse.types import StreamToken
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -56,23 +57,10 @@ class PaginationConfig(object):
@classmethod @classmethod
def from_request(cls, request, raise_invalid_params=True, def from_request(cls, request, raise_invalid_params=True,
default_limit=None): default_limit=None):
def get_param(name, default=None): direction = parse_string(request, "dir", default='f', allowed_values=['f', 'b'])
lst = request.args.get(name, [])
if len(lst) > 1:
raise SynapseError(
400, "%s must be specified only once" % (name,)
)
elif len(lst) == 1:
return lst[0]
else:
return default
direction = get_param("dir", 'f') from_tok = parse_string(request, "from")
if direction not in ['f', 'b']: to_tok = parse_string(request, "to")
raise SynapseError(400, "'dir' parameter is invalid.")
from_tok = get_param("from")
to_tok = get_param("to")
try: try:
if from_tok == "END": if from_tok == "END":
@ -88,12 +76,7 @@ class PaginationConfig(object):
except Exception: except Exception:
raise SynapseError(400, "'to' paramater is invalid") raise SynapseError(400, "'to' paramater is invalid")
limit = get_param("limit", None) limit = parse_integer(request, "limit", default=default_limit)
if limit is not None and not limit.isdigit():
raise SynapseError(400, "'limit' parameter must be an integer.")
if limit is None:
limit = default_limit
try: try:
return PaginationConfig(from_tok, to_tok, direction, limit) return PaginationConfig(from_tok, to_tok, direction, limit)