diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index 1504b00d7e..98d99dd0a8 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -157,22 +157,13 @@ class AuthHandler(BaseHandler): if "user" not in authdict or "password" not in authdict: raise LoginError(400, "", Codes.MISSING_PARAM) - user = authdict["user"] + user_id = authdict["user"] password = authdict["password"] - if not user.startswith('@'): - user = UserID.create(user, self.hs.hostname).to_string() + if not user_id.startswith('@'): + user_id = UserID.create(user_id, self.hs.hostname).to_string() - user_info = yield self.store.get_user_by_id(user_id=user) - if not user_info: - logger.warn("Attempted to login as %s but they do not exist", user) - raise LoginError(401, "", errcode=Codes.UNAUTHORIZED) - - stored_hash = user_info["password_hash"] - if bcrypt.checkpw(password, stored_hash): - defer.returnValue(user) - else: - logger.warn("Failed password login for user %s", user) - raise LoginError(401, "", errcode=Codes.UNAUTHORIZED) + self._check_password(user_id, password) + defer.returnValue(user_id) @defer.inlineCallbacks def _check_recaptcha(self, authdict, clientip): @@ -292,6 +283,16 @@ class AuthHandler(BaseHandler): StoreError if there was a problem storing the token. LoginError if there was an authentication problem. """ + self._check_password(user_id, password) + + reg_handler = self.hs.get_handlers().registration_handler + access_token = reg_handler.generate_token(user_id) + logger.info("Adding token %s for user %s", access_token, user_id) + yield self.store.add_access_token_to_user(user_id, access_token) + defer.returnValue(access_token) + + def _check_password(self, user_id, password): + """Checks that user_id has passed password, raises LoginError if not.""" user_info = yield self.store.get_user_by_id(user_id=user_id) if not user_info: logger.warn("Attempted to login as %s but they do not exist", user_id) @@ -302,12 +303,6 @@ class AuthHandler(BaseHandler): logger.warn("Failed password login for user %s", user_id) raise LoginError(403, "", errcode=Codes.FORBIDDEN) - reg_handler = self.hs.get_handlers().registration_handler - access_token = reg_handler.generate_token(user_id) - logger.info("Adding token %s for user %s", access_token, user_id) - yield self.store.add_access_token_to_user(user_id, access_token) - defer.returnValue(access_token) - @defer.inlineCallbacks def set_password(self, user_id, newpassword): password_hash = bcrypt.hashpw(newpassword, bcrypt.gensalt())