anoa/public_rooms_module_api
Mathieu Velten 2023-05-23 11:59:31 +02:00
parent 7709a99e6f
commit 50d75a311b
4 changed files with 74 additions and 55 deletions

View File

@ -183,11 +183,13 @@ class RoomListHandler:
# Ask each module for a list of public rooms given the last_joined_members # Ask each module for a list of public rooms given the last_joined_members
# value from the since token and the probing limit. # value from the since token and the probing limit.
module_public_rooms = await fetch_public_rooms( module_public_rooms = await fetch_public_rooms(
forwards, network_tuple,
search_filter,
probing_limit, probing_limit,
(batch_token.last_joined_members, batch_token.last_room_id) (batch_token.last_joined_members, batch_token.last_room_id)
if batch_token if batch_token
else None, else None,
forwards,
) )
module_public_rooms.reverse() module_public_rooms.reverse()

View File

@ -15,14 +15,20 @@
import logging import logging
from typing import Awaitable, Callable, List, Optional, Tuple from typing import Awaitable, Callable, List, Optional, Tuple
from synapse.types import PublicRoom from synapse.types import PublicRoom, ThirdPartyInstanceID
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# Types for callbacks to be registered via the module api # Types for callbacks to be registered via the module api
FETCH_PUBLIC_ROOMS_CALLBACK = Callable[ FETCH_PUBLIC_ROOMS_CALLBACK = Callable[
[bool, Optional[int], Optional[Tuple[int, str]]], [
Optional[ThirdPartyInstanceID], # network_tuple
Optional[dict], # search_filter
Optional[int], # limit
Optional[Tuple[int, str]], # bounds
bool, # forwards
],
Awaitable[List[PublicRoom]], Awaitable[List[PublicRoom]],
] ]

View File

@ -382,7 +382,7 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
Args: Args:
network_tuple network_tuple
search_filter search_filter
limit: Maxmimum number of rows to return, unlimited otherwise. limit: Maximum number of rows to return, unlimited otherwise.
bounds: An uppoer or lower bound to apply to result set if given, bounds: An uppoer or lower bound to apply to result set if given,
consists of a joined member count and room_id (these are consists of a joined member count and room_id (these are
excluded from result set). excluded from result set).

View File

@ -18,7 +18,7 @@ from twisted.test.proto_helpers import MemoryReactor
from synapse.rest import admin, login, room from synapse.rest import admin, login, room
from synapse.server import HomeServer from synapse.server import HomeServer
from synapse.types import PublicRoom from synapse.types import PublicRoom, ThirdPartyInstanceID
from synapse.util import Clock from synapse.util import Clock
from tests.unittest import HomeserverTestCase from tests.unittest import HomeserverTestCase
@ -46,47 +46,51 @@ class FetchPublicRoomsTestCase(HomeserverTestCase):
self._module_api = homeserver.get_module_api() self._module_api = homeserver.get_module_api()
async def cb( async def cb(
forwards: bool, limit: Optional[int], bounds: Optional[Tuple[int, str]] network_tuple: Optional[ThirdPartyInstanceID],
search_filter: Optional[dict],
limit: Optional[int],
bounds: Optional[Tuple[int, str]],
forwards: bool,
) -> List[PublicRoom]: ) -> List[PublicRoom]:
rooms_db = [ room1 = PublicRoom(
PublicRoom( room_id="!test1:test",
room_id="!test1:test", num_joined_members=1,
num_joined_members=1, world_readable=True,
world_readable=True, guest_can_join=False,
guest_can_join=False, )
), room3 = PublicRoom(
PublicRoom( room_id="!test3:test",
room_id="!test3:test", num_joined_members=3,
num_joined_members=3, world_readable=True,
world_readable=True, guest_can_join=False,
guest_can_join=False, )
), room3_2 = PublicRoom(
PublicRoom( room_id="!test3_2:test",
room_id="!test3_2:test", num_joined_members=3,
num_joined_members=3, world_readable=True,
world_readable=True, guest_can_join=False,
guest_can_join=False, )
),
]
result = []
if limit is not None and bounds is not None:
(last_joined_members, last_room_id) = bounds
for r in rooms_db:
if r.num_joined_members <= last_joined_members:
if r.room_id == last_room_id:
break
result.append(r)
else:
result = rooms_db
if forwards: if forwards:
result.reverse() if limit == 2:
if bounds is None:
if limit is not None: return [room3_2, room3]
result = result[:limit] (last_joined_members, last_room_id) = bounds
if last_joined_members == 3:
return result if last_room_id == room3_2.room_id:
return [room3, room1]
if last_room_id == room3.room_id:
return [room1]
elif last_joined_members < 3:
return [room1]
return [room3_2, room3, room1]
else:
if limit == 2 and bounds is not None:
(last_joined_members, last_room_id) = bounds
if last_joined_members == 3:
if last_room_id == room3.room_id:
return [room3_2]
return [room1, room3, room3_2]
self._module_api.register_public_rooms_callbacks(fetch_public_rooms=cb) self._module_api.register_public_rooms_callbacks(fetch_public_rooms=cb)
@ -117,27 +121,34 @@ class FetchPublicRoomsTestCase(HomeserverTestCase):
def test_pagination(self) -> None: def test_pagination(self) -> None:
channel = self.make_request("GET", self.url + "?limit=1") channel = self.make_request("GET", self.url + "?limit=1")
self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
self.assertEquals(channel.json_body["chunk"][0]["num_joined_members"], 3) self.assertEquals(channel.json_body["chunk"][0]["num_joined_members"], 3)
returned_room3_id = channel.json_body["chunk"][0]["room_id"] returned_room3_id = channel.json_body["chunk"][0]["room_id"]
next_batch = channel.json_body["next_batch"]
channel = self.make_request( channel = self.make_request("GET", f"{self.url}?limit=1&since={next_batch}")
"GET", self.url + "?limit=1&since=" + channel.json_body["next_batch"] self.assertEquals(channel.json_body["chunk"][0]["num_joined_members"], 3)
# We should get the other room with 3 users here
self.assertNotEquals(
returned_room3_id, channel.json_body["chunk"][0]["room_id"]
) )
self.assertEqual(channel.code, HTTPStatus.OK, channel.result) prev_batch = channel.json_body["prev_batch"]
channel = self.make_request("GET", f"{self.url}?limit=1&since={prev_batch}")
self.assertEquals(channel.json_body["chunk"][0]["num_joined_members"], 3)
self.assertEquals(returned_room3_id, channel.json_body["chunk"][0]["room_id"])
next_batch = channel.json_body["next_batch"]
# We went backwards once, so we should get same result as step 2
channel = self.make_request("GET", f"{self.url}?limit=1&since={next_batch}")
self.assertEquals(channel.json_body["chunk"][0]["num_joined_members"], 3) self.assertEquals(channel.json_body["chunk"][0]["num_joined_members"], 3)
self.assertNotEquals( self.assertNotEquals(
returned_room3_id, channel.json_body["chunk"][0]["room_id"] returned_room3_id, channel.json_body["chunk"][0]["room_id"]
) )
next_batch = channel.json_body["next_batch"]
channel = self.make_request( channel = self.make_request("GET", f"{self.url}?limit=1&since={next_batch}")
"GET", self.url + "?limit=1&since=" + channel.json_body["next_batch"]
)
self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
self.assertEquals(channel.json_body["chunk"][0]["num_joined_members"], 2) self.assertEquals(channel.json_body["chunk"][0]["num_joined_members"], 2)
next_batch = channel.json_body["next_batch"]
channel = self.make_request( channel = self.make_request("GET", f"{self.url}?limit=1&since={next_batch}")
"GET", self.url + "?limit=1&since=" + channel.json_body["next_batch"]
)
self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
self.assertEquals(channel.json_body["chunk"][0]["num_joined_members"], 1) self.assertEquals(channel.json_body["chunk"][0]["num_joined_members"], 1)