Fix stuffs

anoa/public_rooms_module_api
Mathieu Velten 2023-05-19 16:46:02 +02:00
parent e01ea0edc0
commit 74dbcaaab2
3 changed files with 50 additions and 22 deletions

View File

@ -185,7 +185,9 @@ class RoomListHandler:
module_public_rooms = await fetch_public_rooms( module_public_rooms = await fetch_public_rooms(
forwards, forwards,
probing_limit, probing_limit,
batch_token.last_joined_members if batch_token else None, (batch_token.last_joined_members, batch_token.last_room_id)
if batch_token
else None,
) )
# Insert the module's reported public rooms into the list # Insert the module's reported public rooms into the list
@ -212,7 +214,7 @@ class RoomListHandler:
response: JsonDict = {} response: JsonDict = {}
num_results = len(results) num_results = len(results)
if limit is not None: if limit is not None:
more_to_come = num_results == probing_limit more_to_come = num_results >= probing_limit
# Depending on direction we trim either the front or back. # Depending on direction we trim either the front or back.
if forwards: if forwards:

View File

@ -23,7 +23,7 @@ logger = logging.getLogger(__name__)
# Types for callbacks to be registered via the module api # Types for callbacks to be registered via the module api
FETCH_PUBLIC_ROOMS_CALLBACK = Callable[ FETCH_PUBLIC_ROOMS_CALLBACK = Callable[
[bool, Optional[int], Optional[int]], [bool, Optional[int], Optional[Tuple[int, str]]],
Awaitable[Iterable[PublicRoom]], Awaitable[Iterable[PublicRoom]],
] ]

View File

@ -21,6 +21,7 @@ from typing import (
List, List,
Optional, Optional,
TypeVar, TypeVar,
Tuple,
cast, cast,
Iterable, Iterable,
) )
@ -34,7 +35,7 @@ from synapse.util import Clock
from tests.unittest import HomeserverTestCase from tests.unittest import HomeserverTestCase
class FetchPublicRoomsCbTestCase(HomeserverTestCase): class FetchPublicRoomsTestCase(HomeserverTestCase):
servlets = [ servlets = [
admin.register_servlets, admin.register_servlets,
login.register_servlets, login.register_servlets,
@ -45,7 +46,7 @@ class FetchPublicRoomsCbTestCase(HomeserverTestCase):
config = self.default_config() config = self.default_config()
config["allow_public_rooms_without_auth"] = True config["allow_public_rooms_without_auth"] = True
self.hs = self.setup_test_homeserver(config=config) self.hs = self.setup_test_homeserver(config=config)
self.url = b"/_matrix/client/r0/publicRooms" self.url = "/_matrix/client/r0/publicRooms"
return self.hs return self.hs
@ -56,22 +57,26 @@ class FetchPublicRoomsCbTestCase(HomeserverTestCase):
self._module_api = homeserver.get_module_api() self._module_api = homeserver.get_module_api()
async def cb( async def cb(
forwards: bool, limit: Optional[int], last_joined_members: Optional[int] forwards: bool, limit: Optional[int], bounds: Optional[Tuple[int, str]]
) -> Iterable[PublicRoom]: ) -> Iterable[PublicRoom]:
return [ room1 = PublicRoom(
PublicRoom( room_id="!test1:test",
room_id="!test1:test", num_joined_members=1,
num_joined_members=1, world_readable=True,
world_readable=True, guest_can_join=False,
guest_can_join=False, )
), room3 = PublicRoom(
PublicRoom( room_id="!test3:test",
room_id="!test3:test", num_joined_members=3,
num_joined_members=3, world_readable=True,
world_readable=True, guest_can_join=False,
guest_can_join=False, )
) if limit is not None and limit < 3 and bounds is not None:
] (last_joined_members, last_room_id) = bounds
if last_joined_members < 3 or last_room_id == room3["room_id"]:
return [room1]
return [room3, room1]
self._module_api.register_public_rooms_callbacks(fetch_public_rooms=cb) self._module_api.register_public_rooms_callbacks(fetch_public_rooms=cb)
@ -90,7 +95,28 @@ class FetchPublicRoomsCbTestCase(HomeserverTestCase):
token2 = self.login(user2, "pass") token2 = self.login(user2, "pass")
self.helper.join(room_id, user2, tok=token2) self.helper.join(room_id, user2, tok=token2)
def test_public_rooms(self) -> None: def test_no_limit(self) -> None:
channel = self.make_request("GET", self.url) channel = self.make_request("GET", self.url)
print(channel.text_body)
self.assertEqual(channel.code, HTTPStatus.OK, channel.result) self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
self.assertEquals(len(channel.json_body["chunk"]), 3)
self.assertEquals(channel.json_body["chunk"][0]["num_joined_members"], 3)
self.assertEquals(channel.json_body["chunk"][1]["num_joined_members"], 2)
self.assertEquals(channel.json_body["chunk"][2]["num_joined_members"], 1)
def test_pagination(self) -> None:
channel = self.make_request("GET", self.url + "?limit=1")
self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
self.assertEquals(channel.json_body["chunk"][0]["num_joined_members"], 3)
channel = self.make_request(
"GET", self.url + "?limit=1&since=" + channel.json_body["next_batch"]
)
self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
self.assertEquals(channel.json_body["chunk"][0]["num_joined_members"], 2)
channel = self.make_request(
"GET", self.url + "?limit=1&since=" + channel.json_body["next_batch"]
)
self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
self.assertEquals(channel.json_body["chunk"][0]["num_joined_members"], 1)