Merge pull request #2609 from matrix-org/rav/refactor_login

Refactor some logic from LoginRestServlet into AuthHandler
pull/2611/head
David Baker 2017-10-31 13:51:36 +00:00 committed by GitHub
commit c31a7c3ff6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 79 additions and 58 deletions

View File

@ -77,6 +77,12 @@ class AuthHandler(BaseHandler):
self.hs = hs # FIXME better possibility to access registrationHandler later? self.hs = hs # FIXME better possibility to access registrationHandler later?
self.device_handler = hs.get_device_handler() self.device_handler = hs.get_device_handler()
self.macaroon_gen = hs.get_macaroon_generator() self.macaroon_gen = hs.get_macaroon_generator()
self._password_enabled = hs.config.password_enabled
login_types = set()
if self._password_enabled:
login_types.add(LoginType.PASSWORD)
self._supported_login_types = frozenset(login_types)
@defer.inlineCallbacks @defer.inlineCallbacks
def check_auth(self, flows, clientdict, clientip): def check_auth(self, flows, clientdict, clientip):
@ -266,10 +272,11 @@ class AuthHandler(BaseHandler):
user_id = authdict["user"] user_id = authdict["user"]
password = authdict["password"] password = authdict["password"]
if not user_id.startswith('@'):
user_id = UserID(user_id, self.hs.hostname).to_string()
return self._check_password(user_id, password) return self.validate_login(user_id, {
"type": LoginType.PASSWORD,
"password": password,
})
@defer.inlineCallbacks @defer.inlineCallbacks
def _check_recaptcha(self, authdict, clientip): def _check_recaptcha(self, authdict, clientip):
@ -398,23 +405,6 @@ class AuthHandler(BaseHandler):
return self.sessions[session_id] return self.sessions[session_id]
def validate_password_login(self, user_id, password):
"""
Authenticates the user with their username and password.
Used only by the v1 login API.
Args:
user_id (str): complete @user:id
password (str): Password
Returns:
defer.Deferred: (str) canonical user id
Raises:
StoreError if there was a problem accessing the database
LoginError if there was an authentication problem.
"""
return self._check_password(user_id, password)
@defer.inlineCallbacks @defer.inlineCallbacks
def get_access_token_for_user_id(self, user_id, device_id=None, def get_access_token_for_user_id(self, user_id, device_id=None,
initial_display_name=None): initial_display_name=None):
@ -501,26 +491,60 @@ class AuthHandler(BaseHandler):
) )
defer.returnValue(result) defer.returnValue(result)
@defer.inlineCallbacks def get_supported_login_types(self):
def _check_password(self, user_id, password): """Get a the login types supported for the /login API
"""Authenticate a user against the LDAP and local databases.
user_id is checked case insensitively against the local database, but By default this is just 'm.login.password' (unless password_enabled is
will throw if there are multiple inexact matches. False in the config file), but password auth providers can provide
other login types.
Returns:
Iterable[str]: login types
"""
return self._supported_login_types
@defer.inlineCallbacks
def validate_login(self, user_id, login_submission):
"""Authenticates the user for the /login API
Also used by the user-interactive auth flow to validate
m.login.password auth types.
Args: Args:
user_id (str): complete @user:id user_id (str): user_id supplied by the user
login_submission (dict): the whole of the login submission
(including 'type' and other relevant fields)
Returns: Returns:
(str) the canonical_user_id Deferred[str]: canonical user id
Raises: Raises:
LoginError if login fails StoreError if there was a problem accessing the database
SynapseError if there was a problem with the request
LoginError if there was an authentication problem.
""" """
if not user_id.startswith('@'):
user_id = UserID(
user_id, self.hs.hostname
).to_string()
login_type = login_submission.get("type")
if login_type != LoginType.PASSWORD:
raise SynapseError(400, "Bad login type.")
if not self._password_enabled:
raise SynapseError(400, "Password login has been disabled.")
if "password" not in login_submission:
raise SynapseError(400, "Missing parameter: password")
password = login_submission["password"]
for provider in self.password_providers: for provider in self.password_providers:
is_valid = yield provider.check_password(user_id, password) is_valid = yield provider.check_password(user_id, password)
if is_valid: if is_valid:
defer.returnValue(user_id) defer.returnValue(user_id)
canonical_user_id = yield self._check_local_password(user_id, password) canonical_user_id = yield self._check_local_password(
user_id, password,
)
if canonical_user_id: if canonical_user_id:
defer.returnValue(canonical_user_id) defer.returnValue(canonical_user_id)

View File

@ -85,7 +85,6 @@ def login_id_thirdparty_from_phone(identifier):
class LoginRestServlet(ClientV1RestServlet): class LoginRestServlet(ClientV1RestServlet):
PATTERNS = client_path_patterns("/login$") PATTERNS = client_path_patterns("/login$")
PASS_TYPE = "m.login.password"
SAML2_TYPE = "m.login.saml2" SAML2_TYPE = "m.login.saml2"
CAS_TYPE = "m.login.cas" CAS_TYPE = "m.login.cas"
TOKEN_TYPE = "m.login.token" TOKEN_TYPE = "m.login.token"
@ -94,7 +93,6 @@ class LoginRestServlet(ClientV1RestServlet):
def __init__(self, hs): def __init__(self, hs):
super(LoginRestServlet, self).__init__(hs) super(LoginRestServlet, self).__init__(hs)
self.idp_redirect_url = hs.config.saml2_idp_redirect_url self.idp_redirect_url = hs.config.saml2_idp_redirect_url
self.password_enabled = hs.config.password_enabled
self.saml2_enabled = hs.config.saml2_enabled self.saml2_enabled = hs.config.saml2_enabled
self.jwt_enabled = hs.config.jwt_enabled self.jwt_enabled = hs.config.jwt_enabled
self.jwt_secret = hs.config.jwt_secret self.jwt_secret = hs.config.jwt_secret
@ -121,8 +119,10 @@ class LoginRestServlet(ClientV1RestServlet):
# fall back to the fallback API if they don't understand one of the # fall back to the fallback API if they don't understand one of the
# login flow types returned. # login flow types returned.
flows.append({"type": LoginRestServlet.TOKEN_TYPE}) flows.append({"type": LoginRestServlet.TOKEN_TYPE})
if self.password_enabled:
flows.append({"type": LoginRestServlet.PASS_TYPE}) flows.extend((
{"type": t} for t in self.auth_handler.get_supported_login_types()
))
return (200, {"flows": flows}) return (200, {"flows": flows})
@ -133,14 +133,8 @@ class LoginRestServlet(ClientV1RestServlet):
def on_POST(self, request): def on_POST(self, request):
login_submission = parse_json_object_from_request(request) login_submission = parse_json_object_from_request(request)
try: try:
if login_submission["type"] == LoginRestServlet.PASS_TYPE: if self.saml2_enabled and (login_submission["type"] ==
if not self.password_enabled: LoginRestServlet.SAML2_TYPE):
raise SynapseError(400, "Password login has been disabled.")
result = yield self.do_password_login(login_submission)
defer.returnValue(result)
elif self.saml2_enabled and (login_submission["type"] ==
LoginRestServlet.SAML2_TYPE):
relay_state = "" relay_state = ""
if "relay_state" in login_submission: if "relay_state" in login_submission:
relay_state = "&RelayState=" + urllib.quote( relay_state = "&RelayState=" + urllib.quote(
@ -157,15 +151,21 @@ class LoginRestServlet(ClientV1RestServlet):
result = yield self.do_token_login(login_submission) result = yield self.do_token_login(login_submission)
defer.returnValue(result) defer.returnValue(result)
else: else:
raise SynapseError(400, "Bad login type.") result = yield self._do_other_login(login_submission)
defer.returnValue(result)
except KeyError: except KeyError:
raise SynapseError(400, "Missing JSON keys.") raise SynapseError(400, "Missing JSON keys.")
@defer.inlineCallbacks @defer.inlineCallbacks
def do_password_login(self, login_submission): def _do_other_login(self, login_submission):
if "password" not in login_submission: """Handle non-token/saml/jwt logins
raise SynapseError(400, "Missing parameter: password")
Args:
login_submission:
Returns:
(int, object): HTTP code/response
"""
login_submission_legacy_convert(login_submission) login_submission_legacy_convert(login_submission)
if "identifier" not in login_submission: if "identifier" not in login_submission:
@ -208,25 +208,22 @@ class LoginRestServlet(ClientV1RestServlet):
if "user" not in identifier: if "user" not in identifier:
raise SynapseError(400, "User identifier is missing 'user' key") raise SynapseError(400, "User identifier is missing 'user' key")
user_id = identifier["user"]
if not user_id.startswith('@'):
user_id = UserID(
user_id, self.hs.hostname
).to_string()
auth_handler = self.auth_handler auth_handler = self.auth_handler
user_id = yield auth_handler.validate_password_login( canonical_user_id = yield auth_handler.validate_login(
user_id=user_id, identifier["user"],
password=login_submission["password"], login_submission,
)
device_id = yield self._register_device(
canonical_user_id, login_submission,
) )
device_id = yield self._register_device(user_id, login_submission)
access_token = yield auth_handler.get_access_token_for_user_id( access_token = yield auth_handler.get_access_token_for_user_id(
user_id, device_id, canonical_user_id, device_id,
login_submission.get("initial_device_display_name"), login_submission.get("initial_device_display_name"),
) )
result = { result = {
"user_id": user_id, # may have changed "user_id": canonical_user_id,
"access_token": access_token, "access_token": access_token,
"home_server": self.hs.hostname, "home_server": self.hs.hostname,
"device_id": device_id, "device_id": device_id,