diff --git a/synapse/handlers/room_list.py b/synapse/handlers/room_list.py
index 05b6430343..3f51d67821 100644
--- a/synapse/handlers/room_list.py
+++ b/synapse/handlers/room_list.py
@@ -33,7 +33,7 @@ from synapse.api.errors import (
     RequestSendFailed,
     SynapseError,
 )
-from synapse.types import JsonDict, ThirdPartyInstanceID
+from synapse.types import JsonDict, PublicRoom, ThirdPartyInstanceID
 from synapse.util import filter_none
 from synapse.util.caches.descriptors import _CacheContext, cached
 from synapse.util.caches.response_cache import ResponseCache
@@ -156,62 +156,91 @@ class RoomListHandler:
         # plus the direction (forwards or backwards). Next batch tokens always
         # go forwards, prev batch tokens always go backwards.
 
+        forwards = True
+        last_joined_members = None
+        last_room_id = None
+        last_module_index = None
         if since_token:
             batch_token = RoomListNextBatch.from_token(since_token)
             forwards = batch_token.direction_is_forward
-        else:
-            batch_token = None
-            forwards = True
+            last_joined_members = batch_token.last_joined_members
+            last_room_id = batch_token.last_room_id
+            last_module_index = batch_token.last_module_index
 
         # we request one more than wanted to see if there are more pages to come
         probing_limit = limit + 1 if limit is not None else None
 
-        results = await self.store.get_largest_public_rooms(
-            network_tuple,
-            search_filter,
-            probing_limit,
-            bounds=(batch_token.last_joined_members, batch_token.last_room_id)
-            if batch_token
-            else None,
-            forwards=forwards,
-            ignore_non_federatable=bool(from_remote_server_name),
-        )
+        results = []
 
-        for (
-            fetch_public_rooms
-        ) in self._module_api_callbacks.fetch_public_rooms_callbacks:
+        def insert_into_result(
+            new_room: PublicRoom, module_index: Optional[int]
+        ) -> None:
+            """Insert `new_room` into `results`, ordered by joined member count."""
+            if new_room.num_joined_members == last_joined_members:
+                if last_module_index is not None and last_room_id is not None:
+                    if module_index is not None and module_index > last_module_index:
+                        return
+            for i, r in enumerate(results):
+                if (
+                    forwards and new_room.num_joined_members >= r.num_joined_members
+                ) or (
+                    not forwards and new_room.num_joined_members <= r.num_joined_members
+                ):
+                    results.insert(i, new_room)
+                    return
+            if forwards:
+                results.append(new_room)
+            else:
+                results.insert(0, new_room)
+
+        room_ids_to_module_index = {}
+
+        for module_index, fetch_public_rooms in enumerate(
+            self._module_api_callbacks.fetch_public_rooms_callbacks
+        ):
             # Ask each module for a list of public rooms given the last_joined_members
             # value from the since token and the probing limit.
+            module_last_joined_members = None
+            if last_joined_members is not None:
+                module_last_joined_members = last_joined_members
+                if last_module_index is not None and last_module_index < module_index:
+                    # NOTE(review): this always steps the bound down by one; confirm
+                    # that is also intended when `forwards` is False.
+                    module_last_joined_members = module_last_joined_members - 1
             module_public_rooms = await fetch_public_rooms(
                 network_tuple,
                 search_filter,
                 probing_limit,
-                (batch_token.last_joined_members, batch_token.last_room_id)
-                if batch_token
-                else None,
+                (
+                    module_last_joined_members,
+                    last_room_id if last_module_index == module_index else None,
+                ),
                 forwards,
             )
+
+            # We reverse for iteration to keep the order in the final list
+            # since we prepend when inserting
             module_public_rooms.reverse()
 
             # Insert the module's reported public rooms into the list
             for new_room in module_public_rooms:
-                inserted = False
-                for i in range(len(results)):
-                    r = results[i]
-                    if (
-                        forwards and new_room.num_joined_members >= r.num_joined_members
-                    ) or (
-                        not forwards
-                        and new_room.num_joined_members <= r.num_joined_members
-                    ):
-                        results.insert(i, new_room)
-                        inserted = True
-                        break
-                if not inserted:
-                    if forwards:
-                        results.append(new_room)
-                    else:
-                        results.insert(0, new_room)
+                room_ids_to_module_index[new_room.room_id] = module_index
+                insert_into_result(new_room, module_index)
+
+        local_public_rooms = await self.store.get_largest_public_rooms(
+            network_tuple,
+            search_filter,
+            probing_limit,
+            bounds=(
+                last_joined_members,
+                last_room_id if last_module_index is None else None,
+            ),
+            forwards=forwards,
+            ignore_non_federatable=bool(from_remote_server_name),
+        )
+
+        for r in local_public_rooms:
+            insert_into_result(r, None)
 
         response: JsonDict = {}
         num_results = len(results)
@@ -231,13 +260,16 @@ class RoomListHandler:
             initial_entry = results[0]
 
             if forwards:
-                if batch_token is not None:
+                if since_token is not None:
                     # If there was a token given then we assume that there
                     # must be previous results.
                     response["prev_batch"] = RoomListNextBatch(
                         last_joined_members=initial_entry.num_joined_members,
                         last_room_id=initial_entry.room_id,
                         direction_is_forward=False,
+                        last_module_index=room_ids_to_module_index.get(
+                            initial_entry.room_id
+                        ),
                     ).to_token()
 
                 if more_to_come:
@@ -245,13 +277,19 @@ class RoomListHandler:
                         last_joined_members=final_entry.num_joined_members,
                         last_room_id=final_entry.room_id,
                         direction_is_forward=True,
+                        last_module_index=room_ids_to_module_index.get(
+                            final_entry.room_id
+                        ),
                     ).to_token()
             else:
-                if batch_token is not None:
+                if since_token is not None:
                     response["next_batch"] = RoomListNextBatch(
                         last_joined_members=final_entry.num_joined_members,
                         last_room_id=final_entry.room_id,
                         direction_is_forward=True,
+                        last_module_index=room_ids_to_module_index.get(
+                            final_entry.room_id
+                        ),
                     ).to_token()
 
                 if more_to_come:
@@ -259,6 +297,9 @@ class RoomListHandler:
                         last_joined_members=initial_entry.num_joined_members,
                         last_room_id=initial_entry.room_id,
                         direction_is_forward=False,
+                        last_module_index=room_ids_to_module_index.get(
+                            initial_entry.room_id
+                        ),
                     ).to_token()
 
         response["chunk"] = [attr.asdict(r, filter=filter_none) for r in results]
@@ -507,11 +548,13 @@ class RoomListNextBatch:
     last_joined_members: int  # The count to get rooms after/before
     last_room_id: str  # The room_id to get rooms after/before
     direction_is_forward: bool  # True if this is a next_batch, false if prev_batch
+    last_module_index: Optional[int] = None
 
     KEY_DICT = {
         "last_joined_members": "m",
         "last_room_id": "r",
         "direction_is_forward": "d",
+        "last_module_index": "i",
     }
 
     REVERSE_KEY_DICT = {v: k for k, v in KEY_DICT.items()}
diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py
index 02bd76372e..0dcce4fc18 100644
--- a/synapse/storage/databases/main/room.py
+++ b/synapse/storage/databases/main/room.py
@@ -372,7 +372,7 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
         network_tuple: Optional[ThirdPartyInstanceID],
         search_filter: Optional[dict],
         limit: Optional[int],
-        bounds: Optional[Tuple[int, str]],
+        bounds: Tuple[Optional[int], Optional[str]],
         forwards: bool,
         ignore_non_federatable: bool = False,
     ) -> List[PublicRoom]:
@@ -420,26 +420,20 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
         # Work out the bounds if we're given them, these bounds look slightly
         # odd, but are designed to help query planner use indices by pulling
         # out a common bound.
-        if bounds:
-            last_joined_members, last_room_id = bounds
-            if forwards:
-                where_clauses.append(
-                    """
-                        joined_members <= ? AND (
-                            joined_members < ? OR room_id < ?
-                        )
-                    """
-                )
-            else:
-                where_clauses.append(
-                    """
-                        joined_members >= ? AND (
-                            joined_members > ? OR room_id > ?
-                        )
-                    """
-                )
+        last_joined_members, last_room_id = bounds
+        if last_joined_members is not None:
+            comp = "<" if forwards else ">"
 
-            query_args += [last_joined_members, last_joined_members, last_room_id]
+            clause = f"joined_members {comp}= ? AND (joined_members {comp} ?"
+            query_args += [last_joined_members, last_joined_members]
+
+            if last_room_id is None:
+                clause += ")"
+            else:
+                clause += f" OR room_id {comp} ?)"
+                query_args.append(last_room_id)
+
+            where_clauses.append(clause)
 
         if ignore_non_federatable:
             where_clauses.append("is_federatable")
diff --git a/tests/module_api/test_fetch_public_rooms.py b/tests/module_api/test_fetch_public_rooms.py
index dd9bbc1487..f27523d63f 100644
--- a/tests/module_api/test_fetch_public_rooms.py
+++ b/tests/module_api/test_fetch_public_rooms.py
@@ -49,106 +49,179 @@ class FetchPublicRoomsTestCase(HomeserverTestCase):
             network_tuple: Optional[ThirdPartyInstanceID],
             search_filter: Optional[dict],
             limit: Optional[int],
-            bounds: Optional[Tuple[int, str]],
+            bounds: Tuple[Optional[int], Optional[str]],
             forwards: bool,
         ) -> List[PublicRoom]:
             room1 = PublicRoom(
-                room_id="!test1:test",
+                room_id="!one_members:module1",
                 num_joined_members=1,
                 world_readable=True,
                 guest_can_join=False,
             )
 
             room3 = PublicRoom(
-                room_id="!test3:test",
+                room_id="!three_members:module1",
                 num_joined_members=3,
                 world_readable=True,
                 guest_can_join=False,
             )
 
             room3_2 = PublicRoom(
-                room_id="!test3_2:test",
+                room_id="!three_members_2:module1",
                 num_joined_members=3,
                 world_readable=True,
                 guest_can_join=False,
             )
 
-            if forwards:
-                if limit == 2:
-                    if bounds is None:
-                        return [room3_2, room3]
-                    (last_joined_members, last_room_id) = bounds
-                    if last_joined_members == 3:
-                        if last_room_id == room3_2.room_id:
-                            return [room3, room1]
-                        if last_room_id == room3.room_id:
-                            return [room1]
-                    elif last_joined_members < 3:
-                        return [room1]
-                return [room3_2, room3, room1]
-            else:
-                if limit == 2 and bounds is not None:
-                    (last_joined_members, last_room_id) = bounds
-                    if last_joined_members == 3:
-                        if last_room_id == room3.room_id:
-                            return [room3_2]
-                return [room1, room3, room3_2]
+            (last_joined_members, last_room_id) = bounds
+
+            result = [room1, room3, room3_2]
+
+            if last_joined_members is not None:
+                if forwards:
+                    result = list(
+                        filter(
+                            lambda r: r.num_joined_members <= last_joined_members,
+                            result,
+                        )
+                    )
+                else:
+                    result = list(
+                        filter(
+                            lambda r: r.num_joined_members >= last_joined_members,
+                            result,
+                        )
+                    )
+
+            if last_room_id is not None:
+                new_res = []
+                for r in result:
+                    if r.room_id == last_room_id:
+                        break
+                    new_res.append(r)
+                result = new_res
+
+            if forwards:
+                result.reverse()
+
+            if limit is not None:
+                result = result[:limit]
+
+            return result
+
+        async def cb2(
+            network_tuple: Optional[ThirdPartyInstanceID],
+            search_filter: Optional[dict],
+            limit: Optional[int],
+            bounds: Tuple[Optional[int], Optional[str]],
+            forwards: bool,
+        ) -> List[PublicRoom]:
+            room3 = PublicRoom(
+                room_id="!three_members:module2",
+                num_joined_members=3,
+                world_readable=True,
+                guest_can_join=False,
+            )
+
+            result = [room3]
+
+            (last_joined_members, last_room_id) = bounds
+
+            if last_joined_members is not None:
+                if forwards:
+                    result = list(
+                        filter(
+                            lambda r: r.num_joined_members <= last_joined_members,
+                            result,
+                        )
+                    )
+                else:
+                    result = list(
+                        filter(
+                            lambda r: r.num_joined_members >= last_joined_members,
+                            result,
+                        )
+                    )
+
+            if last_room_id is not None:
+                new_res = []
+                for r in result:
+                    if r.room_id == last_room_id:
+                        break
+                    new_res.append(r)
+                result = new_res
+
+            if forwards:
+                result.reverse()
+
+            if limit is not None:
+                result = result[:limit]
+
+            return result
+
+        self._module_api.register_public_rooms_callbacks(fetch_public_rooms=cb2)
         self._module_api.register_public_rooms_callbacks(fetch_public_rooms=cb)
         user = self.register_user("alice", "pass")
         token = self.login(user, "pass")
 
-        # Create a room
+        user2 = self.register_user("alice2", "pass")
+        token2 = self.login(user2, "pass")
+
+        user3 = self.register_user("alice3", "pass")
+        token3 = self.login(user3, "pass")
+
+        # Create a room with 2 people
         room_id = self.helper.create_room_as(
             user,
             is_public=True,
             extra_content={"visibility": "public"},
             tok=token,
         )
-
-        user2 = self.register_user("alice2", "pass")
-        token2 = self.login(user2, "pass")
         self.helper.join(room_id, user2, tok=token2)
 
+        # Create a room with 3 people
+        room_id = self.helper.create_room_as(
+            user,
+            is_public=True,
+            extra_content={"visibility": "public"},
+            tok=token,
+        )
+        self.helper.join(room_id, user2, tok=token2)
+        self.helper.join(room_id, user3, tok=token3)
+
     def test_no_limit(self) -> None:
         channel = self.make_request("GET", self.url)
         self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
+        chunk = channel.json_body["chunk"]
 
-        self.assertEquals(len(channel.json_body["chunk"]), 4)
-        self.assertEquals(channel.json_body["chunk"][0]["num_joined_members"], 3)
-        self.assertEquals(channel.json_body["chunk"][1]["num_joined_members"], 3)
-        self.assertEquals(channel.json_body["chunk"][2]["num_joined_members"], 2)
-        self.assertEquals(channel.json_body["chunk"][3]["num_joined_members"], 1)
+        self.assertEqual(len(chunk), 6)
+        for i in range(4):
+            self.assertEqual(chunk[i]["num_joined_members"], 3)
+        self.assertEqual(chunk[4]["num_joined_members"], 2)
+        self.assertEqual(chunk[5]["num_joined_members"], 1)
 
     def test_pagination(self) -> None:
-        channel = self.make_request("GET", self.url + "?limit=1")
-        self.assertEquals(channel.json_body["chunk"][0]["num_joined_members"], 3)
-        returned_room3_id = channel.json_body["chunk"][0]["room_id"]
+        returned_three_members_rooms = set()
+
+        next_batch = None
+        for _ in range(4):
+            since_query_str = f"&since={next_batch}" if next_batch else ""
+            channel = self.make_request("GET", f"{self.url}?limit=1{since_query_str}")
+            chunk = channel.json_body["chunk"]
+            self.assertEqual(chunk[0]["num_joined_members"], 3)
+            self.assertNotIn(chunk[0]["room_id"], returned_three_members_rooms)
+            returned_three_members_rooms.add(chunk[0]["room_id"])
+            next_batch = channel.json_body["next_batch"]
+
+        channel = self.make_request("GET", f"{self.url}?limit=1&since={next_batch}")
+        chunk = channel.json_body["chunk"]
+        self.assertEqual(chunk[0]["num_joined_members"], 2)
         next_batch = channel.json_body["next_batch"]
 
         channel = self.make_request("GET", f"{self.url}?limit=1&since={next_batch}")
-        self.assertEquals(channel.json_body["chunk"][0]["num_joined_members"], 3)
-        # We should get the other room with 3 users here
-        self.assertNotEquals(
-            returned_room3_id, channel.json_body["chunk"][0]["room_id"]
-        )
+        chunk = channel.json_body["chunk"]
+        self.assertEqual(chunk[0]["num_joined_members"], 1)
         prev_batch = channel.json_body["prev_batch"]
 
-        channel = self.make_request("GET", f"{self.url}?limit=1&since={prev_batch}")
-        self.assertEquals(channel.json_body["chunk"][0]["num_joined_members"], 3)
-        self.assertEquals(returned_room3_id, channel.json_body["chunk"][0]["room_id"])
-        next_batch = channel.json_body["next_batch"]
+        # TODO(review): backwards pagination from `prev_batch` is not exercised
+        # here yet; add assertions walking back through the pages.
-
-        # We went backwards once, so we should get same result as step 2
-        channel = self.make_request("GET", f"{self.url}?limit=1&since={next_batch}")
-        self.assertEquals(channel.json_body["chunk"][0]["num_joined_members"], 3)
-        self.assertNotEquals(
-            returned_room3_id, channel.json_body["chunk"][0]["room_id"]
-        )
-        next_batch = channel.json_body["next_batch"]
-
-        channel = self.make_request("GET", f"{self.url}?limit=1&since={next_batch}")
-        self.assertEquals(channel.json_body["chunk"][0]["num_joined_members"], 2)
-        next_batch = channel.json_body["next_batch"]
-
-        channel = self.make_request("GET", f"{self.url}?limit=1&since={next_batch}")
-        self.assertEquals(channel.json_body["chunk"][0]["num_joined_members"], 1)