Merge pull request #649 from matrix-org/dbkr/idempotent_registration
Make registration idempotentpull/652/head
commit
48b2e853a8
|
@ -35,6 +35,7 @@ logger = logging.getLogger(__name__)
|
|||
|
||||
|
||||
class AuthHandler(BaseHandler):
|
||||
SESSION_EXPIRE_MS = 48 * 60 * 60 * 1000
|
||||
|
||||
def __init__(self, hs):
|
||||
super(AuthHandler, self).__init__(hs)
|
||||
|
@ -66,15 +67,18 @@ class AuthHandler(BaseHandler):
|
|||
'auth' key: this method prompts for auth if none is sent.
|
||||
clientip (str): The IP address of the client.
|
||||
Returns:
|
||||
A tuple of (authed, dict, dict) where authed is true if the client
|
||||
has successfully completed an auth flow. If it is true, the first
|
||||
dict contains the authenticated credentials of each stage.
|
||||
A tuple of (authed, dict, dict, session_id) where authed is true if
|
||||
the client has successfully completed an auth flow. If it is true
|
||||
the first dict contains the authenticated credentials of each stage.
|
||||
|
||||
If authed is false, the first dictionary is the server response to
|
||||
the login request and should be passed back to the client.
|
||||
|
||||
In either case, the second dict contains the parameters for this
|
||||
request (which may have been given only in a previous call).
|
||||
|
||||
session_id is the ID of this session, either passed in by the client
|
||||
or assigned by the call to check_auth
|
||||
"""
|
||||
|
||||
authdict = None
|
||||
|
@ -103,7 +107,10 @@ class AuthHandler(BaseHandler):
|
|||
|
||||
if not authdict:
|
||||
defer.returnValue(
|
||||
(False, self._auth_dict_for_flows(flows, session), clientdict)
|
||||
(
|
||||
False, self._auth_dict_for_flows(flows, session),
|
||||
clientdict, session['id']
|
||||
)
|
||||
)
|
||||
|
||||
if 'creds' not in session:
|
||||
|
@ -122,12 +129,11 @@ class AuthHandler(BaseHandler):
|
|||
for f in flows:
|
||||
if len(set(f) - set(creds.keys())) == 0:
|
||||
logger.info("Auth completed with creds: %r", creds)
|
||||
self._remove_session(session)
|
||||
defer.returnValue((True, creds, clientdict))
|
||||
defer.returnValue((True, creds, clientdict, session['id']))
|
||||
|
||||
ret = self._auth_dict_for_flows(flows, session)
|
||||
ret['completed'] = creds.keys()
|
||||
defer.returnValue((False, ret, clientdict))
|
||||
defer.returnValue((False, ret, clientdict, session['id']))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def add_oob_auth(self, stagetype, authdict, clientip):
|
||||
|
@ -154,6 +160,29 @@ class AuthHandler(BaseHandler):
|
|||
defer.returnValue(True)
|
||||
defer.returnValue(False)
|
||||
|
||||
def set_session_data(self, session_id, key, value):
|
||||
"""
|
||||
Store a key-value pair into the sessions data associated with this
|
||||
request. This data is stored server-side and cannot be modified by
|
||||
the client.
|
||||
:param session_id: (string) The ID of this session as returned from check_auth
|
||||
:param key: (string) The key to store the data under
|
||||
:param value: (any) The data to store
|
||||
"""
|
||||
sess = self._get_session_info(session_id)
|
||||
sess.setdefault('serverdict', {})[key] = value
|
||||
self._save_session(sess)
|
||||
|
||||
def get_session_data(self, session_id, key, default=None):
|
||||
"""
|
||||
Retrieve data stored with set_session_data
|
||||
:param session_id: (string) The ID of this session as returned from check_auth
|
||||
:param key: (string) The key to store the data under
|
||||
:param default: (any) Value to return if the key has not been set
|
||||
"""
|
||||
sess = self._get_session_info(session_id)
|
||||
return sess.setdefault('serverdict', {}).get(key, default)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _check_password_auth(self, authdict, _):
|
||||
if "user" not in authdict or "password" not in authdict:
|
||||
|
@ -455,11 +484,18 @@ class AuthHandler(BaseHandler):
|
|||
def _save_session(self, session):
|
||||
# TODO: Persistent storage
|
||||
logger.debug("Saving session %s", session)
|
||||
session["last_used"] = self.hs.get_clock().time_msec()
|
||||
self.sessions[session["id"]] = session
|
||||
self._prune_sessions()
|
||||
|
||||
def _remove_session(self, session):
|
||||
logger.debug("Removing session %s", session)
|
||||
del self.sessions[session["id"]]
|
||||
def _prune_sessions(self):
|
||||
for sid, sess in self.sessions.items():
|
||||
last_used = 0
|
||||
if 'last_used' in sess:
|
||||
last_used = sess['last_used']
|
||||
now = self.hs.get_clock().time_msec()
|
||||
if last_used < now - AuthHandler.SESSION_EXPIRE_MS:
|
||||
del self.sessions[sid]
|
||||
|
||||
def hash(self, password):
|
||||
"""Computes a secure hash of password.
|
||||
|
|
|
@ -43,7 +43,7 @@ class PasswordRestServlet(RestServlet):
|
|||
|
||||
body = parse_json_object_from_request(request)
|
||||
|
||||
authed, result, params = yield self.auth_handler.check_auth([
|
||||
authed, result, params, _ = yield self.auth_handler.check_auth([
|
||||
[LoginType.PASSWORD],
|
||||
[LoginType.EMAIL_IDENTITY]
|
||||
], body, self.hs.get_ip_from_request(request))
|
||||
|
|
|
@ -139,7 +139,7 @@ class RegisterRestServlet(RestServlet):
|
|||
[LoginType.EMAIL_IDENTITY]
|
||||
]
|
||||
|
||||
authed, result, params = yield self.auth_handler.check_auth(
|
||||
authed, result, params, session_id = yield self.auth_handler.check_auth(
|
||||
flows, body, self.hs.get_ip_from_request(request)
|
||||
)
|
||||
|
||||
|
@ -147,6 +147,26 @@ class RegisterRestServlet(RestServlet):
|
|||
defer.returnValue((401, result))
|
||||
return
|
||||
|
||||
# have we already registered a user for this session
|
||||
registered_user_id = self.auth_handler.get_session_data(
|
||||
session_id, "registered_user_id", None
|
||||
)
|
||||
if registered_user_id is not None:
|
||||
logger.info(
|
||||
"Already registered user ID %r for this session",
|
||||
registered_user_id
|
||||
)
|
||||
access_token = yield self.auth_handler.issue_access_token(registered_user_id)
|
||||
refresh_token = yield self.auth_handler.issue_refresh_token(
|
||||
registered_user_id
|
||||
)
|
||||
defer.returnValue((200, {
|
||||
"user_id": registered_user_id,
|
||||
"access_token": access_token,
|
||||
"home_server": self.hs.hostname,
|
||||
"refresh_token": refresh_token,
|
||||
}))
|
||||
|
||||
# NB: This may be from the auth handler and NOT from the POST
|
||||
if 'password' not in params:
|
||||
raise SynapseError(400, "Missing password.", Codes.MISSING_PARAM)
|
||||
|
@ -161,6 +181,12 @@ class RegisterRestServlet(RestServlet):
|
|||
guest_access_token=guest_access_token,
|
||||
)
|
||||
|
||||
# remember that we've now registered that user account, and with what
|
||||
# user ID (since the user may not have specified)
|
||||
self.auth_handler.set_session_data(
|
||||
session_id, "registered_user_id", user_id
|
||||
)
|
||||
|
||||
if result and LoginType.EMAIL_IDENTITY in result:
|
||||
threepid = result[LoginType.EMAIL_IDENTITY]
|
||||
|
||||
|
|
|
@ -22,9 +22,10 @@ class RegisterRestServletTestCase(unittest.TestCase):
|
|||
side_effect=lambda x: defer.succeed(self.appservice))
|
||||
)
|
||||
|
||||
self.auth_result = (False, None, None)
|
||||
self.auth_result = (False, None, None, None)
|
||||
self.auth_handler = Mock(
|
||||
check_auth=Mock(side_effect=lambda x, y, z: self.auth_result)
|
||||
check_auth=Mock(side_effect=lambda x, y, z: self.auth_result),
|
||||
get_session_data=Mock(return_value=None)
|
||||
)
|
||||
self.registration_handler = Mock()
|
||||
self.identity_handler = Mock()
|
||||
|
@ -112,7 +113,7 @@ class RegisterRestServletTestCase(unittest.TestCase):
|
|||
self.auth_result = (True, None, {
|
||||
"username": "kermit",
|
||||
"password": "monkey"
|
||||
})
|
||||
}, None)
|
||||
self.registration_handler.register = Mock(return_value=(user_id, token))
|
||||
|
||||
(code, result) = yield self.servlet.on_POST(self.request)
|
||||
|
@ -135,7 +136,7 @@ class RegisterRestServletTestCase(unittest.TestCase):
|
|||
self.auth_result = (True, None, {
|
||||
"username": "kermit",
|
||||
"password": "monkey"
|
||||
})
|
||||
}, None)
|
||||
self.registration_handler.register = Mock(return_value=("@user:id", "t"))
|
||||
d = self.servlet.on_POST(self.request)
|
||||
return self.assertFailure(d, SynapseError)
|
||||
|
|
Loading…
Reference in New Issue