anoa/public_rooms_module_api
Mathieu Velten 2023-05-23 11:59:31 +02:00
parent 7709a99e6f
commit 50d75a311b
4 changed files with 74 additions and 55 deletions

View File

@ -183,11 +183,13 @@ class RoomListHandler:
# Ask each module for a list of public rooms given the last_joined_members
# value from the since token and the probing limit.
module_public_rooms = await fetch_public_rooms(
forwards,
network_tuple,
search_filter,
probing_limit,
(batch_token.last_joined_members, batch_token.last_room_id)
if batch_token
else None,
forwards,
)
module_public_rooms.reverse()

View File

@ -15,14 +15,20 @@
import logging
from typing import Awaitable, Callable, List, Optional, Tuple
from synapse.types import PublicRoom
from synapse.types import PublicRoom, ThirdPartyInstanceID
logger = logging.getLogger(__name__)
# Types for callbacks to be registered via the module api
FETCH_PUBLIC_ROOMS_CALLBACK = Callable[
[bool, Optional[int], Optional[Tuple[int, str]]],
[
Optional[ThirdPartyInstanceID], # network_tuple
Optional[dict], # search_filter
Optional[int], # limit
Optional[Tuple[int, str]], # bounds
bool, # forwards
],
Awaitable[List[PublicRoom]],
]

View File

@ -382,7 +382,7 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
Args:
network_tuple
search_filter
limit: Maxmimum number of rows to return, unlimited otherwise.
limit: Maximum number of rows to return, unlimited otherwise.
bounds: An uppoer or lower bound to apply to result set if given,
consists of a joined member count and room_id (these are
excluded from result set).

View File

@ -18,7 +18,7 @@ from twisted.test.proto_helpers import MemoryReactor
from synapse.rest import admin, login, room
from synapse.server import HomeServer
from synapse.types import PublicRoom
from synapse.types import PublicRoom, ThirdPartyInstanceID
from synapse.util import Clock
from tests.unittest import HomeserverTestCase
@ -46,47 +46,51 @@ class FetchPublicRoomsTestCase(HomeserverTestCase):
self._module_api = homeserver.get_module_api()
async def cb(
forwards: bool, limit: Optional[int], bounds: Optional[Tuple[int, str]]
network_tuple: Optional[ThirdPartyInstanceID],
search_filter: Optional[dict],
limit: Optional[int],
bounds: Optional[Tuple[int, str]],
forwards: bool,
) -> List[PublicRoom]:
rooms_db = [
PublicRoom(
room_id="!test1:test",
num_joined_members=1,
world_readable=True,
guest_can_join=False,
),
PublicRoom(
room_id="!test3:test",
num_joined_members=3,
world_readable=True,
guest_can_join=False,
),
PublicRoom(
room_id="!test3_2:test",
num_joined_members=3,
world_readable=True,
guest_can_join=False,
),
]
result = []
if limit is not None and bounds is not None:
(last_joined_members, last_room_id) = bounds
for r in rooms_db:
if r.num_joined_members <= last_joined_members:
if r.room_id == last_room_id:
break
result.append(r)
else:
result = rooms_db
room1 = PublicRoom(
room_id="!test1:test",
num_joined_members=1,
world_readable=True,
guest_can_join=False,
)
room3 = PublicRoom(
room_id="!test3:test",
num_joined_members=3,
world_readable=True,
guest_can_join=False,
)
room3_2 = PublicRoom(
room_id="!test3_2:test",
num_joined_members=3,
world_readable=True,
guest_can_join=False,
)
if forwards:
result.reverse()
if limit is not None:
result = result[:limit]
return result
if limit == 2:
if bounds is None:
return [room3_2, room3]
(last_joined_members, last_room_id) = bounds
if last_joined_members == 3:
if last_room_id == room3_2.room_id:
return [room3, room1]
if last_room_id == room3.room_id:
return [room1]
elif last_joined_members < 3:
return [room1]
return [room3_2, room3, room1]
else:
if limit == 2 and bounds is not None:
(last_joined_members, last_room_id) = bounds
if last_joined_members == 3:
if last_room_id == room3.room_id:
return [room3_2]
return [room1, room3, room3_2]
self._module_api.register_public_rooms_callbacks(fetch_public_rooms=cb)
@ -117,27 +121,34 @@ class FetchPublicRoomsTestCase(HomeserverTestCase):
def test_pagination(self) -> None:
channel = self.make_request("GET", self.url + "?limit=1")
self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
self.assertEquals(channel.json_body["chunk"][0]["num_joined_members"], 3)
returned_room3_id = channel.json_body["chunk"][0]["room_id"]
next_batch = channel.json_body["next_batch"]
channel = self.make_request(
"GET", self.url + "?limit=1&since=" + channel.json_body["next_batch"]
channel = self.make_request("GET", f"{self.url}?limit=1&since={next_batch}")
self.assertEquals(channel.json_body["chunk"][0]["num_joined_members"], 3)
# We should get the other room with 3 users here
self.assertNotEquals(
returned_room3_id, channel.json_body["chunk"][0]["room_id"]
)
self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
prev_batch = channel.json_body["prev_batch"]
channel = self.make_request("GET", f"{self.url}?limit=1&since={prev_batch}")
self.assertEquals(channel.json_body["chunk"][0]["num_joined_members"], 3)
self.assertEquals(returned_room3_id, channel.json_body["chunk"][0]["room_id"])
next_batch = channel.json_body["next_batch"]
# We went backwards once, so we should get same result as step 2
channel = self.make_request("GET", f"{self.url}?limit=1&since={next_batch}")
self.assertEquals(channel.json_body["chunk"][0]["num_joined_members"], 3)
self.assertNotEquals(
returned_room3_id, channel.json_body["chunk"][0]["room_id"]
)
next_batch = channel.json_body["next_batch"]
channel = self.make_request(
"GET", self.url + "?limit=1&since=" + channel.json_body["next_batch"]
)
self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
channel = self.make_request("GET", f"{self.url}?limit=1&since={next_batch}")
self.assertEquals(channel.json_body["chunk"][0]["num_joined_members"], 2)
next_batch = channel.json_body["next_batch"]
channel = self.make_request(
"GET", self.url + "?limit=1&since=" + channel.json_body["next_batch"]
)
self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
channel = self.make_request("GET", f"{self.url}?limit=1&since={next_batch}")
self.assertEquals(channel.json_body["chunk"][0]["num_joined_members"], 1)