diff --git a/changelog.d/6004.bugfix b/changelog.d/6004.bugfix new file mode 100644 index 0000000000..45c179c8fd --- /dev/null +++ b/changelog.d/6004.bugfix @@ -0,0 +1 @@ +Only count real users when checking for auto-creation of auto-join room. diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index 975da57ffd..06bd03b77c 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -275,16 +275,12 @@ class RegistrationHandler(BaseHandler): fake_requester = create_requester(user_id) # try to create the room if we're the first real user on the server. Note - # that an auto-generated support user is not a real user and will never be + # that an auto-generated support or bot user is not a real user and will never be # the user to create the room should_auto_create_rooms = False - is_support = yield self.store.is_support_user(user_id) - # There is an edge case where the first user is the support user, then - # the room is never created, though this seems unlikely and - # recoverable from given the support user being involved in the first - # place. - if self.hs.config.autocreate_auto_join_rooms and not is_support: - count = yield self.store.count_all_users() + is_real_user = yield self.store.is_real_user(user_id) + if self.hs.config.autocreate_auto_join_rooms and is_real_user: + count = yield self.store.count_real_users() should_auto_create_rooms = count == 1 for r in self.hs.config.auto_join_rooms: logger.info("Auto-joining %s to %s", user_id, r) diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py index 5138792a5f..109052fa41 100644 --- a/synapse/storage/registration.py +++ b/synapse/storage/registration.py @@ -322,6 +322,19 @@ class RegistrationWorkerStore(SQLBaseStore): return None + @cachedInlineCallbacks() + def is_real_user(self, user_id): + """Determines if the user is a real user, ie does not have a 'user_type'. + + Args: + user_id (str): user id to test + + Returns: + Deferred[bool]: True if user 'user_type' is null or empty string + """ + res = yield self.runInteraction("is_real_user", self.is_real_user_txn, user_id) + return res + @cachedInlineCallbacks() def is_support_user(self, user_id): """Determines if the user is of type UserTypes.SUPPORT @@ -337,6 +350,16 @@ class RegistrationWorkerStore(SQLBaseStore): ) return res + def is_real_user_txn(self, txn, user_id): + res = self._simple_select_one_onecol_txn( + txn=txn, + table="users", + keyvalues={"name": user_id}, + retcol="user_type", + allow_none=True, + ) + return res is None + def is_support_user_txn(self, txn, user_id): res = self._simple_select_one_onecol_txn( txn=txn, @@ -421,6 +444,20 @@ class RegistrationWorkerStore(SQLBaseStore): ret = yield self.runInteraction("count_users", _count_users) return ret + @defer.inlineCallbacks + def count_real_users(self): + """Counts all users without a special user_type registered on the homeserver.""" + + def _count_users(txn): + txn.execute("SELECT COUNT(*) AS users FROM users where user_type is null") + rows = self.cursor_to_dict(txn) + if rows: + return rows[0]["users"] + return 0 + + ret = yield self.runInteraction("count_real_users", _count_users) + return ret + @defer.inlineCallbacks def find_next_generated_user_id_localpart(self): """ diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py index e10296a5e4..1e9ba3a201 100644 --- a/tests/handlers/test_register.py +++ b/tests/handlers/test_register.py @@ -171,11 +171,11 @@ class RegistrationTestCase(unittest.HomeserverTestCase): rooms = self.get_success(self.store.get_rooms_for_user(user_id)) self.assertEqual(len(rooms), 0) - def test_auto_create_auto_join_rooms_when_support_user_exists(self): + def test_auto_create_auto_join_rooms_when_user_is_not_a_real_user(self): room_alias_str = "#room:test" self.hs.config.auto_join_rooms = [room_alias_str] - self.store.is_support_user = Mock(return_value=True) + self.store.is_real_user = Mock(return_value=False) user_id = self.get_success(self.handler.register_user(localpart="support")) rooms = self.get_success(self.store.get_rooms_for_user(user_id)) self.assertEqual(len(rooms), 0) @@ -183,6 +183,31 @@ class RegistrationTestCase(unittest.HomeserverTestCase): room_alias = RoomAlias.from_string(room_alias_str) self.get_failure(directory_handler.get_association(room_alias), SynapseError) + def test_auto_create_auto_join_rooms_when_user_is_the_first_real_user(self): + room_alias_str = "#room:test" + self.hs.config.auto_join_rooms = [room_alias_str] + + self.store.count_real_users = Mock(return_value=1) + self.store.is_real_user = Mock(return_value=True) + user_id = self.get_success(self.handler.register_user(localpart="real")) + rooms = self.get_success(self.store.get_rooms_for_user(user_id)) + directory_handler = self.hs.get_handlers().directory_handler + room_alias = RoomAlias.from_string(room_alias_str) + room_id = self.get_success(directory_handler.get_association(room_alias)) + + self.assertTrue(room_id["room_id"] in rooms) + self.assertEqual(len(rooms), 1) + + def test_auto_create_auto_join_rooms_when_user_is_not_the_first_real_user(self): + room_alias_str = "#room:test" + self.hs.config.auto_join_rooms = [room_alias_str] + + self.store.count_real_users = Mock(return_value=2) + self.store.is_real_user = Mock(return_value=True) + user_id = self.get_success(self.handler.register_user(localpart="real")) + rooms = self.get_success(self.store.get_rooms_for_user(user_id)) + self.assertEqual(len(rooms), 0) + def test_auto_create_auto_join_where_no_consent(self): """Test to ensure that the first user is not auto-joined to a room if they have not given general consent.