Fix set profile error with Requester.
Replace flush_user with delete access token due to function removal Add a new test case for if the user is already registeredpull/798/head
							parent
							
								
									09804c9862
								
							
						
					
					
						commit
						6fe04ffef2
					
				| 
						 | 
					@ -16,7 +16,7 @@
 | 
				
			||||||
"""Contains functions for registering clients."""
 | 
					"""Contains functions for registering clients."""
 | 
				
			||||||
from twisted.internet import defer
 | 
					from twisted.internet import defer
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from synapse.types import UserID
 | 
					from synapse.types import UserID, Requester
 | 
				
			||||||
from synapse.api.errors import (
 | 
					from synapse.api.errors import (
 | 
				
			||||||
    AuthError, Codes, SynapseError, RegistrationError, InvalidCaptchaError
 | 
					    AuthError, Codes, SynapseError, RegistrationError, InvalidCaptchaError
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
| 
						 | 
					@ -360,7 +360,8 @@ class RegistrationHandler(BaseHandler):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @defer.inlineCallbacks
 | 
					    @defer.inlineCallbacks
 | 
				
			||||||
    def get_or_create_user(self, localpart, displayname, duration_seconds):
 | 
					    def get_or_create_user(self, localpart, displayname, duration_seconds):
 | 
				
			||||||
        """Creates a new user or returns an access token for an existing one
 | 
					        """Creates a new user if the user does not exist,
 | 
				
			||||||
 | 
					        else revokes all previous access tokens and generates a new one.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        Args:
 | 
					        Args:
 | 
				
			||||||
            localpart : The local part of the user ID to register. If None,
 | 
					            localpart : The local part of the user ID to register. If None,
 | 
				
			||||||
| 
						 | 
					@ -399,14 +400,14 @@ class RegistrationHandler(BaseHandler):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            yield registered_user(self.distributor, user)
 | 
					            yield registered_user(self.distributor, user)
 | 
				
			||||||
        else:
 | 
					        else:
 | 
				
			||||||
            yield self.store.flush_user(user_id=user_id)
 | 
					            yield self.store.user_delete_access_tokens(user_id=user_id)
 | 
				
			||||||
            yield self.store.add_access_token_to_user(user_id=user_id, token=token)
 | 
					            yield self.store.add_access_token_to_user(user_id=user_id, token=token)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        if displayname is not None:
 | 
					        if displayname is not None:
 | 
				
			||||||
            logger.info("setting user display name: %s -> %s", user_id, displayname)
 | 
					            logger.info("setting user display name: %s -> %s", user_id, displayname)
 | 
				
			||||||
            profile_handler = self.hs.get_handlers().profile_handler
 | 
					            profile_handler = self.hs.get_handlers().profile_handler
 | 
				
			||||||
            yield profile_handler.set_displayname(
 | 
					            yield profile_handler.set_displayname(
 | 
				
			||||||
                user, user, displayname
 | 
					                user, Requester(user, token, False), displayname
 | 
				
			||||||
            )
 | 
					            )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        defer.returnValue((user_id, token))
 | 
					        defer.returnValue((user_id, token))
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -17,6 +17,7 @@ from twisted.internet import defer
 | 
				
			||||||
from .. import unittest
 | 
					from .. import unittest
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from synapse.handlers.register import RegistrationHandler
 | 
					from synapse.handlers.register import RegistrationHandler
 | 
				
			||||||
 | 
					from synapse.types import UserID
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from tests.utils import setup_test_homeserver
 | 
					from tests.utils import setup_test_homeserver
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -36,25 +37,21 @@ class RegistrationTestCase(unittest.TestCase):
 | 
				
			||||||
        self.mock_distributor = Mock()
 | 
					        self.mock_distributor = Mock()
 | 
				
			||||||
        self.mock_distributor.declare("registered_user")
 | 
					        self.mock_distributor.declare("registered_user")
 | 
				
			||||||
        self.mock_captcha_client = Mock()
 | 
					        self.mock_captcha_client = Mock()
 | 
				
			||||||
        hs = yield setup_test_homeserver(
 | 
					        self.hs = yield setup_test_homeserver(
 | 
				
			||||||
            handlers=None,
 | 
					            handlers=None,
 | 
				
			||||||
            http_client=None,
 | 
					            http_client=None,
 | 
				
			||||||
            expire_access_token=True)
 | 
					            expire_access_token=True)
 | 
				
			||||||
        hs.handlers = RegistrationHandlers(hs)
 | 
					        self.hs.handlers = RegistrationHandlers(self.hs)
 | 
				
			||||||
        self.handler = hs.get_handlers().registration_handler
 | 
					        self.handler = self.hs.get_handlers().registration_handler
 | 
				
			||||||
        hs.get_handlers().profile_handler = Mock()
 | 
					        self.hs.get_handlers().profile_handler = Mock()
 | 
				
			||||||
        self.mock_handler = Mock(spec=[
 | 
					        self.mock_handler = Mock(spec=[
 | 
				
			||||||
            "generate_short_term_login_token",
 | 
					            "generate_short_term_login_token",
 | 
				
			||||||
        ])
 | 
					        ])
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        hs.get_handlers().auth_handler = self.mock_handler
 | 
					        self.hs.get_handlers().auth_handler = self.mock_handler
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @defer.inlineCallbacks
 | 
					    @defer.inlineCallbacks
 | 
				
			||||||
    def test_user_is_created_and_logged_in_if_doesnt_exist(self):
 | 
					    def test_user_is_created_and_logged_in_if_doesnt_exist(self):
 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        Returns:
 | 
					 | 
				
			||||||
            The user doess not exist in this case so it will register and log it in
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        duration_ms = 200
 | 
					        duration_ms = 200
 | 
				
			||||||
        local_part = "someone"
 | 
					        local_part = "someone"
 | 
				
			||||||
        display_name = "someone"
 | 
					        display_name = "someone"
 | 
				
			||||||
| 
						 | 
					@ -65,3 +62,22 @@ class RegistrationTestCase(unittest.TestCase):
 | 
				
			||||||
            local_part, display_name, duration_ms)
 | 
					            local_part, display_name, duration_ms)
 | 
				
			||||||
        self.assertEquals(result_user_id, user_id)
 | 
					        self.assertEquals(result_user_id, user_id)
 | 
				
			||||||
        self.assertEquals(result_token, 'secret')
 | 
					        self.assertEquals(result_token, 'secret')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @defer.inlineCallbacks
 | 
				
			||||||
 | 
					    def test_if_user_exists(self):
 | 
				
			||||||
 | 
					        store = self.hs.get_datastore()
 | 
				
			||||||
 | 
					        frank = UserID.from_string("@frank:test")
 | 
				
			||||||
 | 
					        yield store.register(
 | 
				
			||||||
 | 
					            user_id=frank.to_string(),
 | 
				
			||||||
 | 
					            token="jkv;g498752-43gj['eamb!-5",
 | 
				
			||||||
 | 
					            password_hash=None)
 | 
				
			||||||
 | 
					        duration_ms = 200
 | 
				
			||||||
 | 
					        local_part = "frank"
 | 
				
			||||||
 | 
					        display_name = "Frank"
 | 
				
			||||||
 | 
					        user_id = "@frank:test"
 | 
				
			||||||
 | 
					        mock_token = self.mock_handler.generate_short_term_login_token
 | 
				
			||||||
 | 
					        mock_token.return_value = 'secret'
 | 
				
			||||||
 | 
					        result_user_id, result_token = yield self.handler.get_or_create_user(
 | 
				
			||||||
 | 
					            local_part, display_name, duration_ms)
 | 
				
			||||||
 | 
					        self.assertEquals(result_user_id, user_id)
 | 
				
			||||||
 | 
					        self.assertEquals(result_token, 'secret')
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in New Issue