Fix forwards pagination

anoa/public_rooms_module_api
Mathieu Velten 2023-05-26 18:03:04 +02:00
parent 50d75a311b
commit 3194933c1b
3 changed files with 277 additions and 119 deletions

View File

@ -33,7 +33,7 @@ from synapse.api.errors import (
RequestSendFailed,
SynapseError,
)
from synapse.types import JsonDict, ThirdPartyInstanceID
from synapse.types import JsonDict, PublicRoom, ThirdPartyInstanceID
from synapse.util import filter_none
from synapse.util.caches.descriptors import _CacheContext, cached
from synapse.util.caches.response_cache import ResponseCache
@ -156,62 +156,99 @@ class RoomListHandler:
# plus the direction (forwards or backwards). Next batch tokens always
# go forwards, prev batch tokens always go backwards.
forwards = True
last_joined_members = None
last_room_id = None
last_module_index = None
if since_token:
batch_token = RoomListNextBatch.from_token(since_token)
forwards = batch_token.direction_is_forward
else:
batch_token = None
forwards = True
last_joined_members = batch_token.last_joined_members
last_room_id = batch_token.last_room_id
last_module_index = batch_token.last_module_index
# we request one more than wanted to see if there are more pages to come
probing_limit = limit + 1 if limit is not None else None
results = await self.store.get_largest_public_rooms(
network_tuple,
search_filter,
probing_limit,
bounds=(batch_token.last_joined_members, batch_token.last_room_id)
if batch_token
else None,
forwards=forwards,
ignore_non_federatable=bool(from_remote_server_name),
)
results = []
for (
fetch_public_rooms
) in self._module_api_callbacks.fetch_public_rooms_callbacks:
print(f"last_module_index {last_module_index}")
print(f"last_room_id {last_room_id}")
def insert_into_result(new_room: PublicRoom, module_index: Optional[int]):
# print(f"insert {new_room.room_id} {module_index}")
if new_room.num_joined_members == last_joined_members:
if last_module_index is not None and last_room_id is not None:
if module_index is not None and module_index > last_module_index:
return
inserted = False
for i, r in enumerate(results):
r = results[i]
if (
forwards and new_room.num_joined_members >= r.num_joined_members
) or (
not forwards and new_room.num_joined_members <= r.num_joined_members
):
results.insert(i, new_room)
inserted = True
return
if not inserted:
if forwards:
results.append(new_room)
else:
results.insert(0, new_room)
room_ids_to_module_index = {}
for module_index, fetch_public_rooms in enumerate(
self._module_api_callbacks.fetch_public_rooms_callbacks
):
# Ask each module for a list of public rooms given the last_joined_members
# value from the since token and the probing limit.
module_last_joined_members = None
if last_joined_members is not None:
module_last_joined_members = last_joined_members
if last_module_index is not None and last_module_index < module_index:
module_last_joined_members = module_last_joined_members - 1
module_public_rooms = await fetch_public_rooms(
network_tuple,
search_filter,
probing_limit,
(batch_token.last_joined_members, batch_token.last_room_id)
if batch_token
else None,
(
module_last_joined_members,
last_room_id if last_module_index == module_index else None,
),
forwards,
)
print([r.room_id for r in module_public_rooms])
# We reverse for iteration to keep the order in the final list
# since we preprend when inserting
module_public_rooms.reverse()
# Insert the module's reported public rooms into the list
for new_room in module_public_rooms:
inserted = False
for i in range(len(results)):
r = results[i]
if (
forwards and new_room.num_joined_members >= r.num_joined_members
) or (
not forwards
and new_room.num_joined_members <= r.num_joined_members
):
results.insert(i, new_room)
inserted = True
break
if not inserted:
if forwards:
results.append(new_room)
else:
results.insert(0, new_room)
room_ids_to_module_index[new_room.room_id] = module_index
insert_into_result(new_room, module_index)
local_public_rooms = await self.store.get_largest_public_rooms(
network_tuple,
search_filter,
probing_limit,
bounds=(
last_joined_members,
last_room_id if last_module_index == None else None,
),
forwards=forwards,
ignore_non_federatable=bool(from_remote_server_name),
)
for r in local_public_rooms:
insert_into_result(r, None)
# print("final")
# print([r.room_id for r in results])
response: JsonDict = {}
num_results = len(results)
@ -231,13 +268,16 @@ class RoomListHandler:
initial_entry = results[0]
if forwards:
if batch_token is not None:
if since_token is not None:
# If there was a token given then we assume that there
# must be previous results.
response["prev_batch"] = RoomListNextBatch(
last_joined_members=initial_entry.num_joined_members,
last_room_id=initial_entry.room_id,
direction_is_forward=False,
last_module_index=room_ids_to_module_index.get(
initial_entry.room_id
),
).to_token()
if more_to_come:
@ -245,13 +285,19 @@ class RoomListHandler:
last_joined_members=final_entry.num_joined_members,
last_room_id=final_entry.room_id,
direction_is_forward=True,
last_module_index=room_ids_to_module_index.get(
final_entry.room_id
),
).to_token()
else:
if batch_token is not None:
if since_token is not None:
response["next_batch"] = RoomListNextBatch(
last_joined_members=final_entry.num_joined_members,
last_room_id=final_entry.room_id,
direction_is_forward=True,
last_module_index=room_ids_to_module_index.get(
final_entry.room_id
),
).to_token()
if more_to_come:
@ -259,6 +305,9 @@ class RoomListHandler:
last_joined_members=initial_entry.num_joined_members,
last_room_id=initial_entry.room_id,
direction_is_forward=False,
last_module_index=room_ids_to_module_index.get(
initial_entry.room_id
),
).to_token()
response["chunk"] = [attr.asdict(r, filter=filter_none) for r in results]
@ -507,11 +556,13 @@ class RoomListNextBatch:
last_joined_members: int # The count to get rooms after/before
last_room_id: str # The room_id to get rooms after/before
direction_is_forward: bool # True if this is a next_batch, false if prev_batch
last_module_index: Optional[int] = None
KEY_DICT = {
"last_joined_members": "m",
"last_room_id": "r",
"direction_is_forward": "d",
"last_module_index": "i",
}
REVERSE_KEY_DICT = {v: k for k, v in KEY_DICT.items()}
@ -524,6 +575,7 @@ class RoomListNextBatch:
)
def to_token(self) -> str:
# print(self)
return encode_base64(
msgpack.dumps(
{self.KEY_DICT[key]: val for key, val in attr.asdict(self).items()}

View File

@ -372,7 +372,7 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
network_tuple: Optional[ThirdPartyInstanceID],
search_filter: Optional[dict],
limit: Optional[int],
bounds: Optional[Tuple[int, str]],
bounds: Tuple[Optional[int], Optional[str]],
forwards: bool,
ignore_non_federatable: bool = False,
) -> List[PublicRoom]:
@ -420,26 +420,20 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
# Work out the bounds if we're given them, these bounds look slightly
# odd, but are designed to help query planner use indices by pulling
# out a common bound.
if bounds:
last_joined_members, last_room_id = bounds
if forwards:
where_clauses.append(
"""
joined_members <= ? AND (
joined_members < ? OR room_id < ?
)
"""
)
else:
where_clauses.append(
"""
joined_members >= ? AND (
joined_members > ? OR room_id > ?
)
"""
)
last_joined_members, last_room_id = bounds
if last_joined_members is not None:
comp = "<" if forwards else ">"
query_args += [last_joined_members, last_joined_members, last_room_id]
clause = f"joined_members {comp}= ? AND (joined_members {comp} ?"
query_args += [last_joined_members, last_joined_members]
if last_room_id is None:
clause += ")"
else:
clause += f"OR room_id {comp} ?)"
query_args.append(last_room_id)
where_clauses.append(clause)
if ignore_non_federatable:
where_clauses.append("is_federatable")

View File

@ -49,106 +49,218 @@ class FetchPublicRoomsTestCase(HomeserverTestCase):
network_tuple: Optional[ThirdPartyInstanceID],
search_filter: Optional[dict],
limit: Optional[int],
bounds: Optional[Tuple[int, str]],
bounds: Tuple[Optional[int], Optional[str]],
forwards: bool,
) -> List[PublicRoom]:
room1 = PublicRoom(
room_id="!test1:test",
room_id="!one_members:module1",
num_joined_members=1,
world_readable=True,
guest_can_join=False,
)
room3 = PublicRoom(
room_id="!test3:test",
room_id="!three_members:module1",
num_joined_members=3,
world_readable=True,
guest_can_join=False,
)
room3_2 = PublicRoom(
room_id="!test3_2:test",
room_id="!three_members_2:module1",
num_joined_members=3,
world_readable=True,
guest_can_join=False,
)
if forwards:
if limit == 2:
if bounds is None:
return [room3_2, room3]
(last_joined_members, last_room_id) = bounds
if last_joined_members == 3:
if last_room_id == room3_2.room_id:
return [room3, room1]
if last_room_id == room3.room_id:
return [room1]
elif last_joined_members < 3:
return [room1]
return [room3_2, room3, room1]
else:
if limit == 2 and bounds is not None:
(last_joined_members, last_room_id) = bounds
if last_joined_members == 3:
if last_room_id == room3.room_id:
return [room3_2]
return [room1, room3, room3_2]
(last_joined_members, last_room_id) = bounds
print(f"cb {forwards} {bounds}")
result = [room1, room3, room3_2]
if last_joined_members is not None:
if forwards:
result = list(
filter(
lambda r: r.num_joined_members <= last_joined_members,
result,
)
)
else:
result = list(
filter(
lambda r: r.num_joined_members >= last_joined_members,
result,
)
)
print([r.room_id for r in result])
if last_room_id is not None:
new_res = []
for r in result:
if r.room_id == last_room_id:
break
new_res.append(r)
result = new_res
if forwards:
result.reverse()
if limit is not None:
result = result[:limit]
return result
# if forwards:
# if limit == 2:
# if last_joined_members is None:
# return [room3_2, room3]
# elif last_joined_members == 3:
# if last_room_id == room3_2.room_id:
# return [room3, room1]
# if last_room_id == room3.room_id:
# return [room1]
# elif last_joined_members < 3:
# return [room1]
# return [room3_2, room3, room1]
# else:
# if (
# limit == 2
# and last_joined_members == 3
# and last_room_id == room3.room_id
# ):
# return [room3_2]
# return [room1, room3, room3_2]
async def cb2(
network_tuple: Optional[ThirdPartyInstanceID],
search_filter: Optional[dict],
limit: Optional[int],
bounds: Tuple[Optional[int], Optional[str]],
forwards: bool,
) -> List[PublicRoom]:
room3 = PublicRoom(
room_id="!three_members:module2",
num_joined_members=3,
world_readable=True,
guest_can_join=False,
)
result = [room3]
(last_joined_members, last_room_id) = bounds
print(f"cb2 {forwards} {bounds}")
if last_joined_members is not None:
if forwards:
result = list(
filter(
lambda r: r.num_joined_members <= last_joined_members,
result,
)
)
else:
result = list(
filter(
lambda r: r.num_joined_members >= last_joined_members,
result,
)
)
print([r.room_id for r in result])
if last_room_id is not None:
new_res = []
for r in result:
if r.room_id == last_room_id:
break
new_res.append(r)
result = new_res
if forwards:
result.reverse()
if limit is not None:
result = result[:limit]
return result
self._module_api.register_public_rooms_callbacks(fetch_public_rooms=cb2)
self._module_api.register_public_rooms_callbacks(fetch_public_rooms=cb)
user = self.register_user("alice", "pass")
token = self.login(user, "pass")
# Create a room
user2 = self.register_user("alice2", "pass")
token2 = self.login(user2, "pass")
user3 = self.register_user("alice3", "pass")
token3 = self.login(user3, "pass")
# Create a room with 2 people
room_id = self.helper.create_room_as(
user,
is_public=True,
extra_content={"visibility": "public"},
tok=token,
)
user2 = self.register_user("alice2", "pass")
token2 = self.login(user2, "pass")
self.helper.join(room_id, user2, tok=token2)
# Create a room with 3 people
room_id = self.helper.create_room_as(
user,
is_public=True,
extra_content={"visibility": "public"},
tok=token,
)
self.helper.join(room_id, user2, tok=token2)
self.helper.join(room_id, user3, tok=token3)
def test_no_limit(self) -> None:
channel = self.make_request("GET", self.url)
self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
chunk = channel.json_body["chunk"]
self.assertEquals(len(channel.json_body["chunk"]), 4)
self.assertEquals(channel.json_body["chunk"][0]["num_joined_members"], 3)
self.assertEquals(channel.json_body["chunk"][1]["num_joined_members"], 3)
self.assertEquals(channel.json_body["chunk"][2]["num_joined_members"], 2)
self.assertEquals(channel.json_body["chunk"][3]["num_joined_members"], 1)
self.assertEquals(len(chunk), 6)
for i in range(4):
self.assertEquals(chunk[i]["num_joined_members"], 3)
self.assertEquals(chunk[4]["num_joined_members"], 2)
self.assertEquals(chunk[5]["num_joined_members"], 1)
def test_pagination(self) -> None:
channel = self.make_request("GET", self.url + "?limit=1")
self.assertEquals(channel.json_body["chunk"][0]["num_joined_members"], 3)
returned_room3_id = channel.json_body["chunk"][0]["room_id"]
returned_three_members_rooms = set()
next_batch = None
for i in range(4):
since_query_str = f"&since={next_batch}" if next_batch else ""
channel = self.make_request("GET", f"{self.url}?limit=1{since_query_str}")
chunk = channel.json_body["chunk"]
self.assertEquals(chunk[0]["num_joined_members"], 3)
self.assertTrue(chunk[0]["room_id"] not in returned_three_members_rooms)
returned_three_members_rooms.add(chunk[0]["room_id"])
next_batch = channel.json_body["next_batch"]
channel = self.make_request("GET", f"{self.url}?limit=1&since={next_batch}")
chunk = channel.json_body["chunk"]
self.assertEquals(chunk[0]["num_joined_members"], 2)
next_batch = channel.json_body["next_batch"]
channel = self.make_request("GET", f"{self.url}?limit=1&since={next_batch}")
self.assertEquals(channel.json_body["chunk"][0]["num_joined_members"], 3)
# We should get the other room with 3 users here
self.assertNotEquals(
returned_room3_id, channel.json_body["chunk"][0]["room_id"]
)
chunk = channel.json_body["chunk"]
self.assertEquals(chunk[0]["num_joined_members"], 1)
prev_batch = channel.json_body["prev_batch"]
channel = self.make_request("GET", f"{self.url}?limit=1&since={prev_batch}")
self.assertEquals(channel.json_body["chunk"][0]["num_joined_members"], 3)
self.assertEquals(returned_room3_id, channel.json_body["chunk"][0]["room_id"])
next_batch = channel.json_body["next_batch"]
# channel = self.make_request("GET", f"{self.url}?limit=1&since={prev_batch}")
# chunk = channel.json_body["chunk"]
# print(chunk)
# self.assertEquals(chunk[0]["num_joined_members"], 2)
# prev_batch = channel.json_body["prev_batch"]
# We went backwards once, so we should get same result as step 2
channel = self.make_request("GET", f"{self.url}?limit=1&since={next_batch}")
self.assertEquals(channel.json_body["chunk"][0]["num_joined_members"], 3)
self.assertNotEquals(
returned_room3_id, channel.json_body["chunk"][0]["room_id"]
)
next_batch = channel.json_body["next_batch"]
channel = self.make_request("GET", f"{self.url}?limit=1&since={next_batch}")
self.assertEquals(channel.json_body["chunk"][0]["num_joined_members"], 2)
next_batch = channel.json_body["next_batch"]
channel = self.make_request("GET", f"{self.url}?limit=1&since={next_batch}")
self.assertEquals(channel.json_body["chunk"][0]["num_joined_members"], 1)
# returned_three_members_rooms = set()
# for i in range(4):
# channel = self.make_request("GET", f"{self.url}?limit=1&since={prev_batch}")
# chunk = channel.json_body["chunk"]
# self.assertEquals(chunk[0]["num_joined_members"], 3)
# self.assertTrue(chunk[0]["room_id"] not in returned_three_members_rooms)
# returned_three_members_rooms.add(chunk[0]["room_id"])
# prev_batch = channel.json_body["prev_batch"]