Merge pull request #798 from negzi/bugfix_create_user_feature

Fix set profile error with Requester.
pull/804/head
Erik Johnston 2016-05-24 10:13:53 +01:00
commit d16cc52b5d
2 changed files with 30 additions and 13 deletions

View File

@ -16,7 +16,7 @@
"""Contains functions for registering clients.""" """Contains functions for registering clients."""
from twisted.internet import defer from twisted.internet import defer
from synapse.types import UserID from synapse.types import UserID, Requester
from synapse.api.errors import ( from synapse.api.errors import (
AuthError, Codes, SynapseError, RegistrationError, InvalidCaptchaError AuthError, Codes, SynapseError, RegistrationError, InvalidCaptchaError
) )
@ -360,7 +360,8 @@ class RegistrationHandler(BaseHandler):
@defer.inlineCallbacks @defer.inlineCallbacks
def get_or_create_user(self, localpart, displayname, duration_seconds): def get_or_create_user(self, localpart, displayname, duration_seconds):
"""Creates a new user or returns an access token for an existing one """Creates a new user if the user does not exist,
else revokes all previous access tokens and generates a new one.
Args: Args:
localpart : The local part of the user ID to register. If None, localpart : The local part of the user ID to register. If None,
@ -399,14 +400,14 @@ class RegistrationHandler(BaseHandler):
yield registered_user(self.distributor, user) yield registered_user(self.distributor, user)
else: else:
yield self.store.flush_user(user_id=user_id) yield self.store.user_delete_access_tokens(user_id=user_id)
yield self.store.add_access_token_to_user(user_id=user_id, token=token) yield self.store.add_access_token_to_user(user_id=user_id, token=token)
if displayname is not None: if displayname is not None:
logger.info("setting user display name: %s -> %s", user_id, displayname) logger.info("setting user display name: %s -> %s", user_id, displayname)
profile_handler = self.hs.get_handlers().profile_handler profile_handler = self.hs.get_handlers().profile_handler
yield profile_handler.set_displayname( yield profile_handler.set_displayname(
user, user, displayname user, Requester(user, token, False), displayname
) )
defer.returnValue((user_id, token)) defer.returnValue((user_id, token))

View File

@ -17,6 +17,7 @@ from twisted.internet import defer
from .. import unittest from .. import unittest
from synapse.handlers.register import RegistrationHandler from synapse.handlers.register import RegistrationHandler
from synapse.types import UserID
from tests.utils import setup_test_homeserver from tests.utils import setup_test_homeserver
@ -36,25 +37,21 @@ class RegistrationTestCase(unittest.TestCase):
self.mock_distributor = Mock() self.mock_distributor = Mock()
self.mock_distributor.declare("registered_user") self.mock_distributor.declare("registered_user")
self.mock_captcha_client = Mock() self.mock_captcha_client = Mock()
hs = yield setup_test_homeserver( self.hs = yield setup_test_homeserver(
handlers=None, handlers=None,
http_client=None, http_client=None,
expire_access_token=True) expire_access_token=True)
hs.handlers = RegistrationHandlers(hs) self.hs.handlers = RegistrationHandlers(self.hs)
self.handler = hs.get_handlers().registration_handler self.handler = self.hs.get_handlers().registration_handler
hs.get_handlers().profile_handler = Mock() self.hs.get_handlers().profile_handler = Mock()
self.mock_handler = Mock(spec=[ self.mock_handler = Mock(spec=[
"generate_short_term_login_token", "generate_short_term_login_token",
]) ])
hs.get_handlers().auth_handler = self.mock_handler self.hs.get_handlers().auth_handler = self.mock_handler
@defer.inlineCallbacks @defer.inlineCallbacks
def test_user_is_created_and_logged_in_if_doesnt_exist(self): def test_user_is_created_and_logged_in_if_doesnt_exist(self):
"""
Returns:
The user doess not exist in this case so it will register and log it in
"""
duration_ms = 200 duration_ms = 200
local_part = "someone" local_part = "someone"
display_name = "someone" display_name = "someone"
@ -65,3 +62,22 @@ class RegistrationTestCase(unittest.TestCase):
local_part, display_name, duration_ms) local_part, display_name, duration_ms)
self.assertEquals(result_user_id, user_id) self.assertEquals(result_user_id, user_id)
self.assertEquals(result_token, 'secret') self.assertEquals(result_token, 'secret')
@defer.inlineCallbacks
def test_if_user_exists(self):
store = self.hs.get_datastore()
frank = UserID.from_string("@frank:test")
yield store.register(
user_id=frank.to_string(),
token="jkv;g498752-43gj['eamb!-5",
password_hash=None)
duration_ms = 200
local_part = "frank"
display_name = "Frank"
user_id = "@frank:test"
mock_token = self.mock_handler.generate_short_term_login_token
mock_token.return_value = 'secret'
result_user_id, result_token = yield self.handler.get_or_create_user(
local_part, display_name, duration_ms)
self.assertEquals(result_user_id, user_id)
self.assertEquals(result_token, 'secret')