Merge pull request #649 from matrix-org/dbkr/idempotent_registration

Make registration idempotent
pull/652/head
David Baker 2016-03-16 16:35:45 +00:00
commit 48b2e853a8
4 changed files with 79 additions and 16 deletions

View File

@ -35,6 +35,7 @@ logger = logging.getLogger(__name__)
class AuthHandler(BaseHandler):
SESSION_EXPIRE_MS = 48 * 60 * 60 * 1000
def __init__(self, hs):
super(AuthHandler, self).__init__(hs)
@ -66,15 +67,18 @@ class AuthHandler(BaseHandler):
'auth' key: this method prompts for auth if none is sent.
clientip (str): The IP address of the client.
Returns:
A tuple of (authed, dict, dict) where authed is true if the client
has successfully completed an auth flow. If it is true, the first
dict contains the authenticated credentials of each stage.
A tuple of (authed, dict, dict, session_id) where authed is true if
the client has successfully completed an auth flow. If it is true
the first dict contains the authenticated credentials of each stage.
If authed is false, the first dictionary is the server response to
the login request and should be passed back to the client.
In either case, the second dict contains the parameters for this
request (which may have been given only in a previous call).
session_id is the ID of this session, either passed in by the client
or assigned by the call to check_auth
"""
authdict = None
@ -103,7 +107,10 @@ class AuthHandler(BaseHandler):
if not authdict:
defer.returnValue(
(False, self._auth_dict_for_flows(flows, session), clientdict)
(
False, self._auth_dict_for_flows(flows, session),
clientdict, session['id']
)
)
if 'creds' not in session:
@ -122,12 +129,11 @@ class AuthHandler(BaseHandler):
for f in flows:
if len(set(f) - set(creds.keys())) == 0:
logger.info("Auth completed with creds: %r", creds)
self._remove_session(session)
defer.returnValue((True, creds, clientdict))
defer.returnValue((True, creds, clientdict, session['id']))
ret = self._auth_dict_for_flows(flows, session)
ret['completed'] = creds.keys()
defer.returnValue((False, ret, clientdict))
defer.returnValue((False, ret, clientdict, session['id']))
@defer.inlineCallbacks
def add_oob_auth(self, stagetype, authdict, clientip):
@ -154,6 +160,29 @@ class AuthHandler(BaseHandler):
defer.returnValue(True)
defer.returnValue(False)
def set_session_data(self, session_id, key, value):
"""
Store a key-value pair into the sessions data associated with this
request. This data is stored server-side and cannot be modified by
the client.
:param session_id: (string) The ID of this session as returned from check_auth
:param key: (string) The key to store the data under
:param value: (any) The data to store
"""
sess = self._get_session_info(session_id)
sess.setdefault('serverdict', {})[key] = value
self._save_session(sess)
def get_session_data(self, session_id, key, default=None):
"""
Retrieve data stored with set_session_data
:param session_id: (string) The ID of this session as returned from check_auth
:param key: (string) The key to store the data under
:param default: (any) Value to return if the key has not been set
"""
sess = self._get_session_info(session_id)
return sess.setdefault('serverdict', {}).get(key, default)
@defer.inlineCallbacks
def _check_password_auth(self, authdict, _):
if "user" not in authdict or "password" not in authdict:
@ -455,11 +484,18 @@ class AuthHandler(BaseHandler):
def _save_session(self, session):
# TODO: Persistent storage
logger.debug("Saving session %s", session)
session["last_used"] = self.hs.get_clock().time_msec()
self.sessions[session["id"]] = session
self._prune_sessions()
def _remove_session(self, session):
logger.debug("Removing session %s", session)
del self.sessions[session["id"]]
def _prune_sessions(self):
for sid, sess in self.sessions.items():
last_used = 0
if 'last_used' in sess:
last_used = sess['last_used']
now = self.hs.get_clock().time_msec()
if last_used < now - AuthHandler.SESSION_EXPIRE_MS:
del self.sessions[sid]
def hash(self, password):
"""Computes a secure hash of password.

View File

@ -43,7 +43,7 @@ class PasswordRestServlet(RestServlet):
body = parse_json_object_from_request(request)
authed, result, params = yield self.auth_handler.check_auth([
authed, result, params, _ = yield self.auth_handler.check_auth([
[LoginType.PASSWORD],
[LoginType.EMAIL_IDENTITY]
], body, self.hs.get_ip_from_request(request))

View File

@ -139,7 +139,7 @@ class RegisterRestServlet(RestServlet):
[LoginType.EMAIL_IDENTITY]
]
authed, result, params = yield self.auth_handler.check_auth(
authed, result, params, session_id = yield self.auth_handler.check_auth(
flows, body, self.hs.get_ip_from_request(request)
)
@ -147,6 +147,26 @@ class RegisterRestServlet(RestServlet):
defer.returnValue((401, result))
return
# have we already registered a user for this session
registered_user_id = self.auth_handler.get_session_data(
session_id, "registered_user_id", None
)
if registered_user_id is not None:
logger.info(
"Already registered user ID %r for this session",
registered_user_id
)
access_token = yield self.auth_handler.issue_access_token(registered_user_id)
refresh_token = yield self.auth_handler.issue_refresh_token(
registered_user_id
)
defer.returnValue((200, {
"user_id": registered_user_id,
"access_token": access_token,
"home_server": self.hs.hostname,
"refresh_token": refresh_token,
}))
# NB: This may be from the auth handler and NOT from the POST
if 'password' not in params:
raise SynapseError(400, "Missing password.", Codes.MISSING_PARAM)
@ -161,6 +181,12 @@ class RegisterRestServlet(RestServlet):
guest_access_token=guest_access_token,
)
# remember that we've now registered that user account, and with what
# user ID (since the user may not have specified)
self.auth_handler.set_session_data(
session_id, "registered_user_id", user_id
)
if result and LoginType.EMAIL_IDENTITY in result:
threepid = result[LoginType.EMAIL_IDENTITY]

View File

@ -22,9 +22,10 @@ class RegisterRestServletTestCase(unittest.TestCase):
side_effect=lambda x: defer.succeed(self.appservice))
)
self.auth_result = (False, None, None)
self.auth_result = (False, None, None, None)
self.auth_handler = Mock(
check_auth=Mock(side_effect=lambda x, y, z: self.auth_result)
check_auth=Mock(side_effect=lambda x, y, z: self.auth_result),
get_session_data=Mock(return_value=None)
)
self.registration_handler = Mock()
self.identity_handler = Mock()
@ -112,7 +113,7 @@ class RegisterRestServletTestCase(unittest.TestCase):
self.auth_result = (True, None, {
"username": "kermit",
"password": "monkey"
})
}, None)
self.registration_handler.register = Mock(return_value=(user_id, token))
(code, result) = yield self.servlet.on_POST(self.request)
@ -135,7 +136,7 @@ class RegisterRestServletTestCase(unittest.TestCase):
self.auth_result = (True, None, {
"username": "kermit",
"password": "monkey"
})
}, None)
self.registration_handler.register = Mock(return_value=("@user:id", "t"))
d = self.servlet.on_POST(self.request)
return self.assertFailure(d, SynapseError)