Refactor the user-interactive auth handling (#6105)

Pull the checkers out to their own classes, rather than having them lost in a
massive 1000-line class which does everything.

This is also preparation for some more intelligent advertising of flows, as per #6100
pull/6107/head
Richard van der Hoff 2019-09-25 11:33:03 +01:00 committed by GitHub
parent 8004d6ca2f
commit 2cd98812ba
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 265 additions and 141 deletions

1
changelog.d/6105.misc Normal file
View File

@ -0,0 +1 @@
Refactor the user-interactive auth handling.

View File

@ -21,10 +21,8 @@ import unicodedata
import attr
import bcrypt
import pymacaroons
from canonicaljson import json
from twisted.internet import defer
from twisted.web.client import PartialDownloadError
import synapse.util.stringutils as stringutils
from synapse.api.constants import LoginType
@ -38,7 +36,8 @@ from synapse.api.errors import (
UserDeactivatedError,
)
from synapse.api.ratelimiting import Ratelimiter
from synapse.config.emailconfig import ThreepidBehaviour
from synapse.handlers.ui_auth import INTERACTIVE_AUTH_CHECKERS
from synapse.handlers.ui_auth.checkers import UserInteractiveAuthChecker
from synapse.logging.context import defer_to_thread
from synapse.module_api import ModuleApi
from synapse.types import UserID
@ -58,13 +57,12 @@ class AuthHandler(BaseHandler):
hs (synapse.server.HomeServer):
"""
super(AuthHandler, self).__init__(hs)
self.checkers = {
LoginType.RECAPTCHA: self._check_recaptcha,
LoginType.EMAIL_IDENTITY: self._check_email_identity,
LoginType.MSISDN: self._check_msisdn,
LoginType.DUMMY: self._check_dummy_auth,
LoginType.TERMS: self._check_terms_auth,
}
self.checkers = {} # type: dict[str, UserInteractiveAuthChecker]
for auth_checker_class in INTERACTIVE_AUTH_CHECKERS:
inst = auth_checker_class(hs)
self.checkers[inst.AUTH_TYPE] = inst
self.bcrypt_rounds = hs.config.bcrypt_rounds
# This is not a cache per se, but a store of all current sessions that
@ -292,7 +290,7 @@ class AuthHandler(BaseHandler):
sess["creds"] = {}
creds = sess["creds"]
result = yield self.checkers[stagetype](authdict, clientip)
result = yield self.checkers[stagetype].check_auth(authdict, clientip)
if result:
creds[stagetype] = result
self._save_session(sess)
@ -363,7 +361,7 @@ class AuthHandler(BaseHandler):
login_type = authdict["type"]
checker = self.checkers.get(login_type)
if checker is not None:
res = yield checker(authdict, clientip=clientip)
res = yield checker.check_auth(authdict, clientip=clientip)
return res
# build a v1-login-style dict out of the authdict and fall back to the
@ -376,125 +374,6 @@ class AuthHandler(BaseHandler):
(canonical_id, callback) = yield self.validate_login(user_id, authdict)
return canonical_id
@defer.inlineCallbacks
def _check_recaptcha(self, authdict, clientip, **kwargs):
try:
user_response = authdict["response"]
except KeyError:
# Client tried to provide captcha but didn't give the parameter:
# bad request.
raise LoginError(
400, "Captcha response is required", errcode=Codes.CAPTCHA_NEEDED
)
logger.info(
"Submitting recaptcha response %s with remoteip %s", user_response, clientip
)
# TODO: get this from the homeserver rather than creating a new one for
# each request
try:
client = self.hs.get_simple_http_client()
resp_body = yield client.post_urlencoded_get_json(
self.hs.config.recaptcha_siteverify_api,
args={
"secret": self.hs.config.recaptcha_private_key,
"response": user_response,
"remoteip": clientip,
},
)
except PartialDownloadError as pde:
# Twisted is silly
data = pde.response
resp_body = json.loads(data)
if "success" in resp_body:
# Note that we do NOT check the hostname here: we explicitly
# intend the CAPTCHA to be presented by whatever client the
# user is using, we just care that they have completed a CAPTCHA.
logger.info(
"%s reCAPTCHA from hostname %s",
"Successful" if resp_body["success"] else "Failed",
resp_body.get("hostname"),
)
if resp_body["success"]:
return True
raise LoginError(401, "", errcode=Codes.UNAUTHORIZED)
def _check_email_identity(self, authdict, **kwargs):
return self._check_threepid("email", authdict, **kwargs)
def _check_msisdn(self, authdict, **kwargs):
return self._check_threepid("msisdn", authdict)
def _check_dummy_auth(self, authdict, **kwargs):
return defer.succeed(True)
def _check_terms_auth(self, authdict, **kwargs):
return defer.succeed(True)
@defer.inlineCallbacks
def _check_threepid(self, medium, authdict, **kwargs):
if "threepid_creds" not in authdict:
raise LoginError(400, "Missing threepid_creds", Codes.MISSING_PARAM)
threepid_creds = authdict["threepid_creds"]
identity_handler = self.hs.get_handlers().identity_handler
logger.info("Getting validated threepid. threepidcreds: %r", (threepid_creds,))
if self.hs.config.threepid_behaviour_email == ThreepidBehaviour.REMOTE:
if medium == "email":
threepid = yield identity_handler.threepid_from_creds(
self.hs.config.account_threepid_delegate_email, threepid_creds
)
elif medium == "msisdn":
threepid = yield identity_handler.threepid_from_creds(
self.hs.config.account_threepid_delegate_msisdn, threepid_creds
)
else:
raise SynapseError(400, "Unrecognized threepid medium: %s" % (medium,))
elif self.hs.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL:
row = yield self.store.get_threepid_validation_session(
medium,
threepid_creds["client_secret"],
sid=threepid_creds["sid"],
validated=True,
)
threepid = (
{
"medium": row["medium"],
"address": row["address"],
"validated_at": row["validated_at"],
}
if row
else None
)
if row:
# Valid threepid returned, delete from the db
yield self.store.delete_threepid_session(threepid_creds["sid"])
else:
raise SynapseError(
400, "Password resets are not enabled on this homeserver"
)
if not threepid:
raise LoginError(401, "", errcode=Codes.UNAUTHORIZED)
if threepid["medium"] != medium:
raise LoginError(
401,
"Expecting threepid of type '%s', got '%s'"
% (medium, threepid["medium"]),
errcode=Codes.UNAUTHORIZED,
)
threepid["threepid_creds"] = authdict["threepid_creds"]
return threepid
def _get_params_recaptcha(self):
return {"public_key": self.hs.config.recaptcha_public_key}

View File

@ -0,0 +1,22 @@
# -*- coding: utf-8 -*-
# Copyright 2019 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""This module implements user-interactive auth verification.
TODO: move more stuff out of AuthHandler in here.
"""
from synapse.handlers.ui_auth.checkers import INTERACTIVE_AUTH_CHECKERS # noqa: F401

View File

@ -0,0 +1,216 @@
# -*- coding: utf-8 -*-
# Copyright 2019 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from canonicaljson import json
from twisted.internet import defer
from twisted.web.client import PartialDownloadError
from synapse.api.constants import LoginType
from synapse.api.errors import Codes, LoginError, SynapseError
from synapse.config.emailconfig import ThreepidBehaviour
logger = logging.getLogger(__name__)
class UserInteractiveAuthChecker:
"""Abstract base class for an interactive auth checker"""
def __init__(self, hs):
pass
def check_auth(self, authdict, clientip):
"""Given the authentication dict from the client, attempt to check this step
Args:
authdict (dict): authentication dictionary from the client
clientip (str): The IP address of the client.
Raises:
SynapseError if authentication failed
Returns:
Deferred: the result of authentication (to pass back to the client?)
"""
raise NotImplementedError()
class DummyAuthChecker(UserInteractiveAuthChecker):
AUTH_TYPE = LoginType.DUMMY
def check_auth(self, authdict, clientip):
return defer.succeed(True)
class TermsAuthChecker(UserInteractiveAuthChecker):
AUTH_TYPE = LoginType.TERMS
def check_auth(self, authdict, clientip):
return defer.succeed(True)
class RecaptchaAuthChecker(UserInteractiveAuthChecker):
AUTH_TYPE = LoginType.RECAPTCHA
def __init__(self, hs):
super().__init__(hs)
self._http_client = hs.get_simple_http_client()
self._url = hs.config.recaptcha_siteverify_api
self._secret = hs.config.recaptcha_private_key
@defer.inlineCallbacks
def check_auth(self, authdict, clientip):
try:
user_response = authdict["response"]
except KeyError:
# Client tried to provide captcha but didn't give the parameter:
# bad request.
raise LoginError(
400, "Captcha response is required", errcode=Codes.CAPTCHA_NEEDED
)
logger.info(
"Submitting recaptcha response %s with remoteip %s", user_response, clientip
)
# TODO: get this from the homeserver rather than creating a new one for
# each request
try:
resp_body = yield self._http_client.post_urlencoded_get_json(
self._url,
args={
"secret": self._secret,
"response": user_response,
"remoteip": clientip,
},
)
except PartialDownloadError as pde:
# Twisted is silly
data = pde.response
resp_body = json.loads(data)
if "success" in resp_body:
# Note that we do NOT check the hostname here: we explicitly
# intend the CAPTCHA to be presented by whatever client the
# user is using, we just care that they have completed a CAPTCHA.
logger.info(
"%s reCAPTCHA from hostname %s",
"Successful" if resp_body["success"] else "Failed",
resp_body.get("hostname"),
)
if resp_body["success"]:
return True
raise LoginError(401, "", errcode=Codes.UNAUTHORIZED)
class _BaseThreepidAuthChecker:
def __init__(self, hs):
self.hs = hs
self.store = hs.get_datastore()
@defer.inlineCallbacks
def _check_threepid(self, medium, authdict):
if "threepid_creds" not in authdict:
raise LoginError(400, "Missing threepid_creds", Codes.MISSING_PARAM)
threepid_creds = authdict["threepid_creds"]
identity_handler = self.hs.get_handlers().identity_handler
logger.info("Getting validated threepid. threepidcreds: %r", (threepid_creds,))
if self.hs.config.threepid_behaviour_email == ThreepidBehaviour.REMOTE:
if medium == "email":
threepid = yield identity_handler.threepid_from_creds(
self.hs.config.account_threepid_delegate_email, threepid_creds
)
elif medium == "msisdn":
threepid = yield identity_handler.threepid_from_creds(
self.hs.config.account_threepid_delegate_msisdn, threepid_creds
)
else:
raise SynapseError(400, "Unrecognized threepid medium: %s" % (medium,))
elif self.hs.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL:
row = yield self.store.get_threepid_validation_session(
medium,
threepid_creds["client_secret"],
sid=threepid_creds["sid"],
validated=True,
)
threepid = (
{
"medium": row["medium"],
"address": row["address"],
"validated_at": row["validated_at"],
}
if row
else None
)
if row:
# Valid threepid returned, delete from the db
yield self.store.delete_threepid_session(threepid_creds["sid"])
else:
raise SynapseError(
400, "Password resets are not enabled on this homeserver"
)
if not threepid:
raise LoginError(401, "", errcode=Codes.UNAUTHORIZED)
if threepid["medium"] != medium:
raise LoginError(
401,
"Expecting threepid of type '%s', got '%s'"
% (medium, threepid["medium"]),
errcode=Codes.UNAUTHORIZED,
)
threepid["threepid_creds"] = authdict["threepid_creds"]
return threepid
class EmailIdentityAuthChecker(UserInteractiveAuthChecker, _BaseThreepidAuthChecker):
AUTH_TYPE = LoginType.EMAIL_IDENTITY
def __init__(self, hs):
UserInteractiveAuthChecker.__init__(self, hs)
_BaseThreepidAuthChecker.__init__(self, hs)
def check_auth(self, authdict, clientip):
return self._check_threepid("email", authdict)
class MsisdnAuthChecker(UserInteractiveAuthChecker, _BaseThreepidAuthChecker):
AUTH_TYPE = LoginType.MSISDN
def __init__(self, hs):
UserInteractiveAuthChecker.__init__(self, hs)
_BaseThreepidAuthChecker.__init__(self, hs)
def check_auth(self, authdict, clientip):
return self._check_threepid("msisdn", authdict)
INTERACTIVE_AUTH_CHECKERS = [
DummyAuthChecker,
TermsAuthChecker,
RecaptchaAuthChecker,
EmailIdentityAuthChecker,
MsisdnAuthChecker,
]
"""A list of UserInteractiveAuthChecker classes"""

View File

@ -18,11 +18,22 @@ from twisted.internet.defer import succeed
import synapse.rest.admin
from synapse.api.constants import LoginType
from synapse.handlers.ui_auth.checkers import UserInteractiveAuthChecker
from synapse.rest.client.v2_alpha import auth, register
from tests import unittest
class DummyRecaptchaChecker(UserInteractiveAuthChecker):
def __init__(self, hs):
super().__init__(hs)
self.recaptcha_attempts = []
def check_auth(self, authdict, clientip):
self.recaptcha_attempts.append((authdict, clientip))
return succeed(True)
class FallbackAuthTests(unittest.HomeserverTestCase):
servlets = [
@ -44,15 +55,9 @@ class FallbackAuthTests(unittest.HomeserverTestCase):
return hs
def prepare(self, reactor, clock, hs):
self.recaptcha_checker = DummyRecaptchaChecker(hs)
auth_handler = hs.get_auth_handler()
self.recaptcha_attempts = []
def _recaptcha(authdict, clientip):
self.recaptcha_attempts.append((authdict, clientip))
return succeed(True)
auth_handler.checkers[LoginType.RECAPTCHA] = _recaptcha
auth_handler.checkers[LoginType.RECAPTCHA] = self.recaptcha_checker
@unittest.INFO
def test_fallback_captcha(self):
@ -89,8 +94,9 @@ class FallbackAuthTests(unittest.HomeserverTestCase):
self.assertEqual(request.code, 200)
# The recaptcha handler is called with the response given
self.assertEqual(len(self.recaptcha_attempts), 1)
self.assertEqual(self.recaptcha_attempts[0][0]["response"], "a")
attempts = self.recaptcha_checker.recaptcha_attempts
self.assertEqual(len(attempts), 1)
self.assertEqual(attempts[0][0]["response"], "a")
# also complete the dummy auth
request, channel = self.make_request(