From 8aee5aa06807210c17ad0e58e4f237fcf2d052f9 Mon Sep 17 00:00:00 2001 From: Mark Haines Date: Fri, 9 Sep 2016 16:29:10 +0100 Subject: [PATCH] Add helper function for getting access_tokens from requests Rather than reimplementing the token parsing in the various places. This will make it easier to change the token parsing to allow access_tokens in HTTP headers. --- synapse/api/auth.py | 58 +++++++++++++++++++--- synapse/rest/client/v1/logout.py | 10 +--- synapse/rest/client/v1/register.py | 12 ++--- synapse/rest/client/v1/transactions.py | 4 +- synapse/rest/client/v2_alpha/register.py | 6 ++- synapse/rest/client/v2_alpha/thirdparty.py | 4 +- 6 files changed, 67 insertions(+), 27 deletions(-) diff --git a/synapse/api/auth.py b/synapse/api/auth.py index dcda40863f..98a50f0948 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -583,12 +583,15 @@ class Auth(object): """ # Can optionally look elsewhere in the request (e.g. headers) try: - user_id = yield self._get_appservice_user_id(request.args) + user_id = yield self._get_appservice_user_id(request) if user_id: request.authenticated_entity = user_id defer.returnValue(synapse.types.create_requester(user_id)) - access_token = request.args["access_token"][0] + access_token = get_access_token_from_request( + request, self.TOKEN_NOT_FOUND_HTTP_STATUS + ) + user_info = yield self.get_user_by_access_token(access_token, rights) user = user_info["user"] token_id = user_info["token_id"] @@ -629,17 +632,19 @@ class Auth(object): ) @defer.inlineCallbacks - def _get_appservice_user_id(self, request_args): + def _get_appservice_user_id(self, request): app_service = yield self.store.get_app_service_by_token( - request_args["access_token"][0] + get_access_token_from_request( + request, self.TOKEN_NOT_FOUND_HTTP_STATUS + ) ) if app_service is None: defer.returnValue(None) - if "user_id" not in request_args: + if "user_id" not in request.args: defer.returnValue(app_service.sender) - user_id = request_args["user_id"][0] + user_id = request.args["user_id"][0] if app_service.sender == user_id: defer.returnValue(app_service.sender) @@ -833,7 +838,9 @@ class Auth(object): @defer.inlineCallbacks def get_appservice_by_req(self, request): try: - token = request.args["access_token"][0] + token = get_access_token_from_request( + request, self.TOKEN_NOT_FOUND_HTTP_STATUS + ) service = yield self.store.get_app_service_by_token(token) if not service: logger.warn("Unrecognised appservice access token: %s" % (token,)) @@ -1142,3 +1149,40 @@ class Auth(object): "This server requires you to be a moderator in the room to" " edit its room list entry" ) + + +def has_access_token(request): + """Checks if the request has an access_token. + + Returns: + bool: False if no access_token was given, True otherwise. + """ + query_params = request.args.get("access_token") + return bool(query_params) + + +def get_access_token_from_request(request, token_not_found_http_status=401): + """Extracts the access_token from the request. + + Args: + request: The http request. + token_not_found_http_status(int): The HTTP status code to set in the + AuthError if the token isn't found. This is used in some of the + legacy APIs to change the status code to 403 from the default of + 401 since some of the old clients depended on auth errors returning + 403. + Returns: + str: The access_token + Raises: + AuthError: If there isn't an access_token in the request. + """ + query_params = request.args.get("access_token") + # Try to get the access_token from the query params. + if not query_params: + raise AuthError( + token_not_found_http_status, + "Missing access token.", + errcode=Codes.MISSING_TOKEN + ) + + return query_params[0] diff --git a/synapse/rest/client/v1/logout.py b/synapse/rest/client/v1/logout.py index 9bff02ee4e..1358d0acab 100644 --- a/synapse/rest/client/v1/logout.py +++ b/synapse/rest/client/v1/logout.py @@ -15,7 +15,7 @@ from twisted.internet import defer -from synapse.api.errors import AuthError, Codes +from synapse.api.auth import get_access_token_from_request from .base import ClientV1RestServlet, client_path_patterns @@ -37,13 +37,7 @@ class LogoutRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def on_POST(self, request): - try: - access_token = request.args["access_token"][0] - except KeyError: - raise AuthError( - self.TOKEN_NOT_FOUND_HTTP_STATUS, "Missing access token.", - errcode=Codes.MISSING_TOKEN - ) + access_token = get_access_token_from_request(request) yield self.store.delete_access_token(access_token) defer.returnValue((200, {})) diff --git a/synapse/rest/client/v1/register.py b/synapse/rest/client/v1/register.py index 71d58c8e8d..3046da7aec 100644 --- a/synapse/rest/client/v1/register.py +++ b/synapse/rest/client/v1/register.py @@ -18,6 +18,7 @@ from twisted.internet import defer from synapse.api.errors import SynapseError, Codes from synapse.api.constants import LoginType +from synapse.api.auth import get_access_token_from_request from .base import ClientV1RestServlet, client_path_patterns import synapse.util.stringutils as stringutils from synapse.http.servlet import parse_json_object_from_request @@ -296,12 +297,11 @@ class RegisterRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def _do_app_service(self, request, register_json, session): - if "access_token" not in request.args: - raise SynapseError(400, "Expected application service token.") + as_token = get_access_token_from_request(request) + if "user" not in register_json: raise SynapseError(400, "Expected 'user' key.") - as_token = request.args["access_token"][0] user_localpart = register_json["user"].encode("utf-8") handler = self.handlers.registration_handler @@ -390,11 +390,9 @@ class CreateUserRestServlet(ClientV1RestServlet): def on_POST(self, request): user_json = parse_json_object_from_request(request) - if "access_token" not in request.args: - raise SynapseError(400, "Expected application service token.") - + access_token = get_access_token_from_request(request) app_service = yield self.store.get_app_service_by_token( - request.args["access_token"][0] + access_token ) if not app_service: raise SynapseError(403, "Invalid application service token.") diff --git a/synapse/rest/client/v1/transactions.py b/synapse/rest/client/v1/transactions.py index bdccf464a5..2f2c9d0881 100644 --- a/synapse/rest/client/v1/transactions.py +++ b/synapse/rest/client/v1/transactions.py @@ -17,6 +17,8 @@ to ensure idempotency when performing PUTs using the REST API.""" import logging +from synapse.api.auth import get_access_token_from_request + logger = logging.getLogger(__name__) @@ -90,6 +92,6 @@ class HttpTransactionStore(object): return response def _get_key(self, request): - token = request.args["access_token"][0] + token = get_access_token_from_request(request) path_without_txn_id = request.path.rsplit("/", 1)[0] return path_without_txn_id + "/" + token diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py index 2121bd75ea..68d18a9b82 100644 --- a/synapse/rest/client/v2_alpha/register.py +++ b/synapse/rest/client/v2_alpha/register.py @@ -15,6 +15,7 @@ from twisted.internet import defer +from synapse.api.auth import get_access_token_from_request, has_access_token from synapse.api.constants import LoginType from synapse.api.errors import SynapseError, Codes, UnrecognizedRequestError from synapse.http.servlet import RestServlet, parse_json_object_from_request @@ -131,7 +132,7 @@ class RegisterRestServlet(RestServlet): desired_username = body['username'] appservice = None - if 'access_token' in request.args: + if has_access_token(request): appservice = yield self.auth.get_appservice_by_req(request) # fork off as soon as possible for ASes and shared secret auth which @@ -143,10 +144,11 @@ class RegisterRestServlet(RestServlet): # 'user' key not 'username'). Since this is a new addition, we'll # fallback to 'username' if they gave one. desired_username = body.get("user", desired_username) + access_token = get_access_token_from_request(request) if isinstance(desired_username, basestring): result = yield self._do_appservice_registration( - desired_username, request.args["access_token"][0], body + desired_username, access_token, body ) defer.returnValue((200, result)) # we throw for non 200 responses return diff --git a/synapse/rest/client/v2_alpha/thirdparty.py b/synapse/rest/client/v2_alpha/thirdparty.py index 4f6f1a7e17..b3e73c0271 100644 --- a/synapse/rest/client/v2_alpha/thirdparty.py +++ b/synapse/rest/client/v2_alpha/thirdparty.py @@ -57,7 +57,7 @@ class ThirdPartyUserServlet(RestServlet): yield self.auth.get_user_by_req(request) fields = request.args - del fields["access_token"] + fields.pop("access_token", None) results = yield self.appservice_handler.query_3pe( ThirdPartyEntityKind.USER, protocol, fields @@ -81,7 +81,7 @@ class ThirdPartyLocationServlet(RestServlet): yield self.auth.get_user_by_req(request) fields = request.args - del fields["access_token"] + fields.pop("access_token", None) results = yield self.appservice_handler.query_3pe( ThirdPartyEntityKind.LOCATION, protocol, fields