diff --git a/synapse/handlers/room_list.py b/synapse/handlers/room_list.py index 2f66f5abb5..1f37b57373 100644 --- a/synapse/handlers/room_list.py +++ b/synapse/handlers/room_list.py @@ -185,7 +185,9 @@ class RoomListHandler: module_public_rooms = await fetch_public_rooms( forwards, probing_limit, - batch_token.last_joined_members if batch_token else None, + (batch_token.last_joined_members, batch_token.last_room_id) + if batch_token + else None, ) # Insert the module's reported public rooms into the list @@ -212,7 +214,7 @@ class RoomListHandler: response: JsonDict = {} num_results = len(results) if limit is not None: - more_to_come = num_results == probing_limit + more_to_come = num_results >= probing_limit # Depending on direction we trim either the front or back. if forwards: diff --git a/synapse/module_api/callbacks/public_rooms_callbacks.py b/synapse/module_api/callbacks/public_rooms_callbacks.py index 969b1b1faa..847becc89b 100644 --- a/synapse/module_api/callbacks/public_rooms_callbacks.py +++ b/synapse/module_api/callbacks/public_rooms_callbacks.py @@ -23,7 +23,7 @@ logger = logging.getLogger(__name__) # Types for callbacks to be registered via the module api FETCH_PUBLIC_ROOMS_CALLBACK = Callable[ - [bool, Optional[int], Optional[int]], + [bool, Optional[int], Optional[Tuple[int, str]]], Awaitable[Iterable[PublicRoom]], ] diff --git a/tests/module_api/test_fetch_public_rooms_cb.py b/tests/module_api/test_fetch_public_rooms.py similarity index 53% rename from tests/module_api/test_fetch_public_rooms_cb.py rename to tests/module_api/test_fetch_public_rooms.py index a62764ab72..1486fbfab9 100644 --- a/tests/module_api/test_fetch_public_rooms_cb.py +++ b/tests/module_api/test_fetch_public_rooms.py @@ -21,6 +21,7 @@ from typing import ( List, Optional, TypeVar, + Tuple, cast, Iterable, ) @@ -34,7 +35,7 @@ from synapse.util import Clock from tests.unittest import HomeserverTestCase -class FetchPublicRoomsCbTestCase(HomeserverTestCase): +class FetchPublicRoomsTestCase(HomeserverTestCase): servlets = [ admin.register_servlets, login.register_servlets, @@ -45,7 +46,7 @@ class FetchPublicRoomsCbTestCase(HomeserverTestCase): config = self.default_config() config["allow_public_rooms_without_auth"] = True self.hs = self.setup_test_homeserver(config=config) - self.url = b"/_matrix/client/r0/publicRooms" + self.url = "/_matrix/client/r0/publicRooms" return self.hs @@ -56,22 +57,26 @@ class FetchPublicRoomsCbTestCase(HomeserverTestCase): self._module_api = homeserver.get_module_api() async def cb( - forwards: bool, limit: Optional[int], last_joined_members: Optional[int] + forwards: bool, limit: Optional[int], bounds: Optional[Tuple[int, str]] ) -> Iterable[PublicRoom]: - return [ - PublicRoom( - room_id="!test1:test", - num_joined_members=1, - world_readable=True, - guest_can_join=False, - ), - PublicRoom( - room_id="!test3:test", - num_joined_members=3, - world_readable=True, - guest_can_join=False, - ) - ] + room1 = PublicRoom( + room_id="!test1:test", + num_joined_members=1, + world_readable=True, + guest_can_join=False, + ) + room3 = PublicRoom( + room_id="!test3:test", + num_joined_members=3, + world_readable=True, + guest_can_join=False, + ) + if limit is not None and limit < 3 and bounds is not None: + (last_joined_members, last_room_id) = bounds + if last_joined_members < 3 or last_room_id == room3["room_id"]: + return [room1] + + return [room3, room1] self._module_api.register_public_rooms_callbacks(fetch_public_rooms=cb) @@ -90,7 +95,28 @@ class FetchPublicRoomsCbTestCase(HomeserverTestCase): token2 = self.login(user2, "pass") self.helper.join(room_id, user2, tok=token2) - def test_public_rooms(self) -> None: + def test_no_limit(self) -> None: channel = self.make_request("GET", self.url) - print(channel.text_body) self.assertEqual(channel.code, HTTPStatus.OK, channel.result) + + self.assertEquals(len(channel.json_body["chunk"]), 3) + self.assertEquals(channel.json_body["chunk"][0]["num_joined_members"], 3) + self.assertEquals(channel.json_body["chunk"][1]["num_joined_members"], 2) + self.assertEquals(channel.json_body["chunk"][2]["num_joined_members"], 1) + + def test_pagination(self) -> None: + channel = self.make_request("GET", self.url + "?limit=1") + self.assertEqual(channel.code, HTTPStatus.OK, channel.result) + self.assertEquals(channel.json_body["chunk"][0]["num_joined_members"], 3) + + channel = self.make_request( + "GET", self.url + "?limit=1&since=" + channel.json_body["next_batch"] + ) + self.assertEqual(channel.code, HTTPStatus.OK, channel.result) + self.assertEquals(channel.json_body["chunk"][0]["num_joined_members"], 2) + + channel = self.make_request( + "GET", self.url + "?limit=1&since=" + channel.json_body["next_batch"] + ) + self.assertEqual(channel.code, HTTPStatus.OK, channel.result) + self.assertEquals(channel.json_body["chunk"][0]["num_joined_members"], 1) \ No newline at end of file