Merge pull request #817 from matrix-org/dbkr/split_out_auth_handler

Split out the auth handler
pull/819/head
David Baker 2016-06-02 14:31:35 +01:00
commit fb2193cc63
10 changed files with 23 additions and 24 deletions

View File

@ -24,7 +24,6 @@ from .federation import FederationHandler
from .profile import ProfileHandler from .profile import ProfileHandler
from .directory import DirectoryHandler from .directory import DirectoryHandler
from .admin import AdminHandler from .admin import AdminHandler
from .auth import AuthHandler
from .identity import IdentityHandler from .identity import IdentityHandler
from .receipts import ReceiptsHandler from .receipts import ReceiptsHandler
from .search import SearchHandler from .search import SearchHandler
@ -50,7 +49,6 @@ class Handlers(object):
self.directory_handler = DirectoryHandler(hs) self.directory_handler = DirectoryHandler(hs)
self.admin_handler = AdminHandler(hs) self.admin_handler = AdminHandler(hs)
self.receipts_handler = ReceiptsHandler(hs) self.receipts_handler = ReceiptsHandler(hs)
self.auth_handler = AuthHandler(hs)
self.identity_handler = IdentityHandler(hs) self.identity_handler = IdentityHandler(hs)
self.search_handler = SearchHandler(hs) self.search_handler = SearchHandler(hs)
self.room_context_handler = RoomContextHandler(hs) self.room_context_handler = RoomContextHandler(hs)

View File

@ -413,7 +413,7 @@ class RegistrationHandler(BaseHandler):
defer.returnValue((user_id, token)) defer.returnValue((user_id, token))
def auth_handler(self): def auth_handler(self):
return self.hs.get_handlers().auth_handler return self.hs.get_auth_handler()
@defer.inlineCallbacks @defer.inlineCallbacks
def guest_access_token_for(self, medium, address, inviter_user_id): def guest_access_token_for(self, medium, address, inviter_user_id):

View File

@ -58,6 +58,7 @@ class LoginRestServlet(ClientV1RestServlet):
self.cas_required_attributes = hs.config.cas_required_attributes self.cas_required_attributes = hs.config.cas_required_attributes
self.servername = hs.config.server_name self.servername = hs.config.server_name
self.http_client = hs.get_simple_http_client() self.http_client = hs.get_simple_http_client()
self.auth_handler = self.hs.get_auth_handler()
def on_GET(self, request): def on_GET(self, request):
flows = [] flows = []
@ -143,7 +144,7 @@ class LoginRestServlet(ClientV1RestServlet):
user_id, self.hs.hostname user_id, self.hs.hostname
).to_string() ).to_string()
auth_handler = self.handlers.auth_handler auth_handler = self.auth_handler
user_id, access_token, refresh_token = yield auth_handler.login_with_password( user_id, access_token, refresh_token = yield auth_handler.login_with_password(
user_id=user_id, user_id=user_id,
password=login_submission["password"]) password=login_submission["password"])
@ -160,7 +161,7 @@ class LoginRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def do_token_login(self, login_submission): def do_token_login(self, login_submission):
token = login_submission['token'] token = login_submission['token']
auth_handler = self.handlers.auth_handler auth_handler = self.auth_handler
user_id = ( user_id = (
yield auth_handler.validate_short_term_login_token_and_get_user_id(token) yield auth_handler.validate_short_term_login_token_and_get_user_id(token)
) )
@ -194,7 +195,7 @@ class LoginRestServlet(ClientV1RestServlet):
raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED) raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED)
user_id = UserID.create(user, self.hs.hostname).to_string() user_id = UserID.create(user, self.hs.hostname).to_string()
auth_handler = self.handlers.auth_handler auth_handler = self.auth_handler
user_exists = yield auth_handler.does_user_exist(user_id) user_exists = yield auth_handler.does_user_exist(user_id)
if user_exists: if user_exists:
user_id, access_token, refresh_token = ( user_id, access_token, refresh_token = (
@ -243,7 +244,7 @@ class LoginRestServlet(ClientV1RestServlet):
raise LoginError(401, "Invalid JWT", errcode=Codes.UNAUTHORIZED) raise LoginError(401, "Invalid JWT", errcode=Codes.UNAUTHORIZED)
user_id = UserID.create(user, self.hs.hostname).to_string() user_id = UserID.create(user, self.hs.hostname).to_string()
auth_handler = self.handlers.auth_handler auth_handler = self.auth_handler
user_exists = yield auth_handler.does_user_exist(user_id) user_exists = yield auth_handler.does_user_exist(user_id)
if user_exists: if user_exists:
user_id, access_token, refresh_token = ( user_id, access_token, refresh_token = (
@ -412,7 +413,7 @@ class CasTicketServlet(ClientV1RestServlet):
raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED) raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED)
user_id = UserID.create(user, self.hs.hostname).to_string() user_id = UserID.create(user, self.hs.hostname).to_string()
auth_handler = self.handlers.auth_handler auth_handler = self.auth_handler
user_exists = yield auth_handler.does_user_exist(user_id) user_exists = yield auth_handler.does_user_exist(user_id)
if not user_exists: if not user_exists:
user_id, _ = ( user_id, _ = (

View File

@ -35,7 +35,7 @@ class PasswordRestServlet(RestServlet):
super(PasswordRestServlet, self).__init__() super(PasswordRestServlet, self).__init__()
self.hs = hs self.hs = hs
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.auth_handler = hs.get_handlers().auth_handler self.auth_handler = hs.get_auth_handler()
@defer.inlineCallbacks @defer.inlineCallbacks
def on_POST(self, request): def on_POST(self, request):
@ -97,7 +97,7 @@ class ThreepidRestServlet(RestServlet):
self.hs = hs self.hs = hs
self.identity_handler = hs.get_handlers().identity_handler self.identity_handler = hs.get_handlers().identity_handler
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.auth_handler = hs.get_handlers().auth_handler self.auth_handler = hs.get_auth_handler()
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request): def on_GET(self, request):

View File

@ -104,7 +104,7 @@ class AuthRestServlet(RestServlet):
super(AuthRestServlet, self).__init__() super(AuthRestServlet, self).__init__()
self.hs = hs self.hs = hs
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.auth_handler = hs.get_handlers().auth_handler self.auth_handler = hs.get_auth_handler()
self.registration_handler = hs.get_handlers().registration_handler self.registration_handler = hs.get_handlers().registration_handler
@defer.inlineCallbacks @defer.inlineCallbacks

View File

@ -49,7 +49,7 @@ class RegisterRestServlet(RestServlet):
self.hs = hs self.hs = hs
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.auth_handler = hs.get_handlers().auth_handler self.auth_handler = hs.get_auth_handler()
self.registration_handler = hs.get_handlers().registration_handler self.registration_handler = hs.get_handlers().registration_handler
self.identity_handler = hs.get_handlers().identity_handler self.identity_handler = hs.get_handlers().identity_handler

View File

@ -38,7 +38,7 @@ class TokenRefreshRestServlet(RestServlet):
body = parse_json_object_from_request(request) body = parse_json_object_from_request(request)
try: try:
old_refresh_token = body["refresh_token"] old_refresh_token = body["refresh_token"]
auth_handler = self.hs.get_handlers().auth_handler auth_handler = self.hs.get_auth_handler()
(user_id, new_refresh_token) = yield self.store.exchange_refresh_token( (user_id, new_refresh_token) = yield self.store.exchange_refresh_token(
old_refresh_token, auth_handler.generate_refresh_token) old_refresh_token, auth_handler.generate_refresh_token)
new_access_token = yield auth_handler.issue_access_token(user_id) new_access_token = yield auth_handler.issue_access_token(user_id)

View File

@ -33,6 +33,7 @@ from synapse.handlers.presence import PresenceHandler
from synapse.handlers.sync import SyncHandler from synapse.handlers.sync import SyncHandler
from synapse.handlers.typing import TypingHandler from synapse.handlers.typing import TypingHandler
from synapse.handlers.room import RoomListHandler from synapse.handlers.room import RoomListHandler
from synapse.handlers.auth import AuthHandler
from synapse.handlers.appservice import ApplicationServicesHandler from synapse.handlers.appservice import ApplicationServicesHandler
from synapse.state import StateHandler from synapse.state import StateHandler
from synapse.storage import DataStore from synapse.storage import DataStore
@ -89,6 +90,7 @@ class HomeServer(object):
'sync_handler', 'sync_handler',
'typing_handler', 'typing_handler',
'room_list_handler', 'room_list_handler',
'auth_handler',
'application_service_api', 'application_service_api',
'application_service_scheduler', 'application_service_scheduler',
'application_service_handler', 'application_service_handler',
@ -190,6 +192,9 @@ class HomeServer(object):
def build_room_list_handler(self): def build_room_list_handler(self):
return RoomListHandler(self) return RoomListHandler(self)
def build_auth_handler(self):
return AuthHandler(self)
def build_application_service_api(self): def build_application_service_api(self):
return ApplicationServiceApi(self) return ApplicationServiceApi(self)

View File

@ -33,7 +33,6 @@ class RegisterRestServletTestCase(unittest.TestCase):
# do the dance to hook it up to the hs global # do the dance to hook it up to the hs global
self.handlers = Mock( self.handlers = Mock(
auth_handler=self.auth_handler,
registration_handler=self.registration_handler, registration_handler=self.registration_handler,
identity_handler=self.identity_handler, identity_handler=self.identity_handler,
login_handler=self.login_handler login_handler=self.login_handler
@ -42,6 +41,7 @@ class RegisterRestServletTestCase(unittest.TestCase):
self.hs.hostname = "superbig~testing~thing.com" self.hs.hostname = "superbig~testing~thing.com"
self.hs.get_auth = Mock(return_value=self.auth) self.hs.get_auth = Mock(return_value=self.auth)
self.hs.get_handlers = Mock(return_value=self.handlers) self.hs.get_handlers = Mock(return_value=self.handlers)
self.hs.get_auth_handler = Mock(return_value=self.auth_handler)
self.hs.config.enable_registration = True self.hs.config.enable_registration = True
# init the thing we're testing # init the thing we're testing

View File

@ -81,16 +81,11 @@ def setup_test_homeserver(name="test", datastore=None, config=None, **kargs):
) )
# bcrypt is far too slow to be doing in unit tests # bcrypt is far too slow to be doing in unit tests
def swap_out_hash_for_testing(old_build_handlers): # Need to let the HS build an auth handler and then mess with it
def build_handlers(): # because AuthHandler's constructor requires the HS, so we can't make one
handlers = old_build_handlers() # beforehand and pass it in to the HS's constructor (chicken / egg)
auth_handler = handlers.auth_handler hs.get_auth_handler().hash = lambda p: hashlib.md5(p).hexdigest()
auth_handler.hash = lambda p: hashlib.md5(p).hexdigest() hs.get_auth_handler().validate_hash = lambda p, h: hashlib.md5(p).hexdigest() == h
auth_handler.validate_hash = lambda p, h: hashlib.md5(p).hexdigest() == h
return handlers
return build_handlers
hs.build_handlers = swap_out_hash_for_testing(hs.build_handlers)
fed = kargs.get("resource_for_federation", None) fed = kargs.get("resource_for_federation", None)
if fed: if fed: