diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py index e2bbfc9d93..8a53617629 100644 --- a/synapse/rest/client/v2_alpha/account.py +++ b/synapse/rest/client/v2_alpha/account.py @@ -28,24 +28,54 @@ import logging logger = logging.getLogger(__name__) +class PasswordRequestTokenRestServlet(RestServlet): + PATTERNS = client_v2_patterns("/account/password/email/requestToken$") + + def __init__(self, hs): + super(PasswordRequestTokenRestServlet, self).__init__() + self.hs = hs + self.identity_handler = hs.get_handlers().identity_handler + + @defer.inlineCallbacks + def on_POST(self, request): + body = parse_json_object_from_request(request) + + required = ['id_server', 'client_secret', 'email', 'send_attempt'] + absent = [] + for k in required: + if k not in body: + absent.append(k) + + if len(absent) > 0: + raise SynapseError(400, "Missing params: %r" % absent, Codes.MISSING_PARAM) + + existingUid = yield self.hs.get_datastore().get_user_id_by_threepid( + 'email', body['email'] + ) + + if existingUid is None: + raise SynapseError(400, "Email not found", Codes.THREEPID_NOT_FOUND) + + ret = yield self.identity_handler.requestEmailToken(**body) + defer.returnValue((200, ret)) + + def on_OPTIONS(self, _): + return 200, {} + + class PasswordRestServlet(RestServlet): - PATTERNS = client_v2_patterns("/account/password") + PATTERNS = client_v2_patterns("/account/password$") def __init__(self, hs): super(PasswordRestServlet, self).__init__() self.hs = hs self.auth = hs.get_auth() self.auth_handler = hs.get_auth_handler() - self.identity_handler = hs.get_handlers().identity_handler @defer.inlineCallbacks def on_POST(self, request): yield run_on_reactor() - if '/account/password/email/requestToken' in request.path: - ret = yield self.onPasswordEmailTokenRequest(request) - defer.returnValue(ret) - body = parse_json_object_from_request(request) authed, result, params, _ = yield self.auth_handler.check_auth([ @@ -90,8 +120,20 @@ class PasswordRestServlet(RestServlet): defer.returnValue((200, {})) + def on_OPTIONS(self, _): + return 200, {} + + +class ThreepidRequestTokenRestServlet(RestServlet): + PATTERNS = client_v2_patterns("/account/3pid/email/requestToken$") + + def __init__(self, hs): + self.hs = hs + super(ThreepidRequestTokenRestServlet, self).__init__() + self.identity_handler = hs.get_handlers().identity_handler + @defer.inlineCallbacks - def onPasswordEmailTokenRequest(self, request): + def on_POST(self, request): body = parse_json_object_from_request(request) required = ['id_server', 'client_secret', 'email', 'send_attempt'] @@ -107,8 +149,10 @@ class PasswordRestServlet(RestServlet): 'email', body['email'] ) - if existingUid is None: - raise SynapseError(400, "Email not found", Codes.THREEPID_NOT_FOUND) + logger.error("existing %r", existingUid) + + if existingUid is not None: + raise SynapseError(400, "Email is already in use", Codes.THREEPID_IN_USE) ret = yield self.identity_handler.requestEmailToken(**body) defer.returnValue((200, ret)) @@ -118,7 +162,7 @@ class PasswordRestServlet(RestServlet): class ThreepidRestServlet(RestServlet): - PATTERNS = client_v2_patterns("/account/3pid") + PATTERNS = client_v2_patterns("/account/3pid$") def __init__(self, hs): super(ThreepidRestServlet, self).__init__() @@ -143,10 +187,6 @@ class ThreepidRestServlet(RestServlet): def on_POST(self, request): yield run_on_reactor() - if '/account/3pid/email/requestToken' in request.path: - ret = yield self.onThreepidEmailTokenRequest(request) - defer.returnValue(ret) - body = parse_json_object_from_request(request) threePidCreds = body.get('threePidCreds') @@ -187,30 +227,9 @@ class ThreepidRestServlet(RestServlet): defer.returnValue((200, {})) - @defer.inlineCallbacks - def onThreepidEmailTokenRequest(self, request): - body = parse_json_object_from_request(request) - - required = ['id_server', 'client_secret', 'email', 'send_attempt'] - absent = [] - for k in required: - if k not in body: - absent.append(k) - - if len(absent) > 0: - raise SynapseError(400, "Missing params: %r" % absent, Codes.MISSING_PARAM) - - existingUid = yield self.hs.get_datastore().get_user_id_by_threepid( - 'email', body['email'] - ) - - if existingUid is not None: - raise SynapseError(400, "Email is already in use", Codes.THREEPID_IN_USE) - - ret = yield self.identity_handler.requestEmailToken(**body) - defer.returnValue((200, ret)) - def register_servlets(hs, http_server): + PasswordRequestTokenRestServlet(hs).register(http_server) PasswordRestServlet(hs).register(http_server) + ThreepidRequestTokenRestServlet(hs).register(http_server) ThreepidRestServlet(hs).register(http_server) diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py index 2088c316d1..e5944b99b1 100644 --- a/synapse/rest/client/v2_alpha/register.py +++ b/synapse/rest/client/v2_alpha/register.py @@ -41,8 +41,43 @@ else: logger = logging.getLogger(__name__) +class RegisterRequestTokenRestServlet(RestServlet): + PATTERNS = client_v2_patterns("/register/email/requestToken$") + + def __init__(self, hs): + super(RegisterRequestTokenRestServlet, self).__init__() + self.hs = hs + self.identity_handler = hs.get_handlers().identity_handler + + @defer.inlineCallbacks + def on_POST(self, request): + body = parse_json_object_from_request(request) + + required = ['id_server', 'client_secret', 'email', 'send_attempt'] + absent = [] + for k in required: + if k not in body: + absent.append(k) + + if len(absent) > 0: + raise SynapseError(400, "Missing params: %r" % absent, Codes.MISSING_PARAM) + + existingUid = yield self.hs.get_datastore().get_user_id_by_threepid( + 'email', body['email'] + ) + + if existingUid is not None: + raise SynapseError(400, "Email is already in use", Codes.THREEPID_IN_USE) + + ret = yield self.identity_handler.requestEmailToken(**body) + defer.returnValue((200, ret)) + + def on_OPTIONS(self, _): + return 200, {} + + class RegisterRestServlet(RestServlet): - PATTERNS = client_v2_patterns("/register") + PATTERNS = client_v2_patterns("/register$") def __init__(self, hs): super(RegisterRestServlet, self).__init__() @@ -70,10 +105,6 @@ class RegisterRestServlet(RestServlet): "Do not understand membership kind: %s" % (kind,) ) - if '/register/email/requestToken' in request.path: - ret = yield self.onEmailTokenRequest(request) - defer.returnValue(ret) - body = parse_json_object_from_request(request) # we do basic sanity checks here because the auth layer will store these @@ -305,29 +336,6 @@ class RegisterRestServlet(RestServlet): "refresh_token": refresh_token, }) - @defer.inlineCallbacks - def onEmailTokenRequest(self, request): - body = parse_json_object_from_request(request) - - required = ['id_server', 'client_secret', 'email', 'send_attempt'] - absent = [] - for k in required: - if k not in body: - absent.append(k) - - if len(absent) > 0: - raise SynapseError(400, "Missing params: %r" % absent, Codes.MISSING_PARAM) - - existingUid = yield self.hs.get_datastore().get_user_id_by_threepid( - 'email', body['email'] - ) - - if existingUid is not None: - raise SynapseError(400, "Email is already in use", Codes.THREEPID_IN_USE) - - ret = yield self.identity_handler.requestEmailToken(**body) - defer.returnValue((200, ret)) - @defer.inlineCallbacks def _do_guest_registration(self): if not self.hs.config.allow_guest_access: @@ -345,4 +353,5 @@ class RegisterRestServlet(RestServlet): def register_servlets(hs, http_server): + RegisterRequestTokenRestServlet(hs).register(http_server) RegisterRestServlet(hs).register(http_server)