Fix forwards pagination
parent
50d75a311b
commit
3194933c1b
|
@ -33,7 +33,7 @@ from synapse.api.errors import (
|
|||
RequestSendFailed,
|
||||
SynapseError,
|
||||
)
|
||||
from synapse.types import JsonDict, ThirdPartyInstanceID
|
||||
from synapse.types import JsonDict, PublicRoom, ThirdPartyInstanceID
|
||||
from synapse.util import filter_none
|
||||
from synapse.util.caches.descriptors import _CacheContext, cached
|
||||
from synapse.util.caches.response_cache import ResponseCache
|
||||
|
@ -156,62 +156,99 @@ class RoomListHandler:
|
|||
# plus the direction (forwards or backwards). Next batch tokens always
|
||||
# go forwards, prev batch tokens always go backwards.
|
||||
|
||||
forwards = True
|
||||
last_joined_members = None
|
||||
last_room_id = None
|
||||
last_module_index = None
|
||||
if since_token:
|
||||
batch_token = RoomListNextBatch.from_token(since_token)
|
||||
forwards = batch_token.direction_is_forward
|
||||
else:
|
||||
batch_token = None
|
||||
forwards = True
|
||||
last_joined_members = batch_token.last_joined_members
|
||||
last_room_id = batch_token.last_room_id
|
||||
last_module_index = batch_token.last_module_index
|
||||
|
||||
# we request one more than wanted to see if there are more pages to come
|
||||
probing_limit = limit + 1 if limit is not None else None
|
||||
|
||||
results = await self.store.get_largest_public_rooms(
|
||||
network_tuple,
|
||||
search_filter,
|
||||
probing_limit,
|
||||
bounds=(batch_token.last_joined_members, batch_token.last_room_id)
|
||||
if batch_token
|
||||
else None,
|
||||
forwards=forwards,
|
||||
ignore_non_federatable=bool(from_remote_server_name),
|
||||
)
|
||||
results = []
|
||||
|
||||
for (
|
||||
fetch_public_rooms
|
||||
) in self._module_api_callbacks.fetch_public_rooms_callbacks:
|
||||
print(f"last_module_index {last_module_index}")
|
||||
print(f"last_room_id {last_room_id}")
|
||||
|
||||
def insert_into_result(new_room: PublicRoom, module_index: Optional[int]):
|
||||
# print(f"insert {new_room.room_id} {module_index}")
|
||||
if new_room.num_joined_members == last_joined_members:
|
||||
if last_module_index is not None and last_room_id is not None:
|
||||
if module_index is not None and module_index > last_module_index:
|
||||
return
|
||||
inserted = False
|
||||
for i, r in enumerate(results):
|
||||
r = results[i]
|
||||
if (
|
||||
forwards and new_room.num_joined_members >= r.num_joined_members
|
||||
) or (
|
||||
not forwards and new_room.num_joined_members <= r.num_joined_members
|
||||
):
|
||||
results.insert(i, new_room)
|
||||
inserted = True
|
||||
return
|
||||
if not inserted:
|
||||
if forwards:
|
||||
results.append(new_room)
|
||||
else:
|
||||
results.insert(0, new_room)
|
||||
|
||||
room_ids_to_module_index = {}
|
||||
|
||||
for module_index, fetch_public_rooms in enumerate(
|
||||
self._module_api_callbacks.fetch_public_rooms_callbacks
|
||||
):
|
||||
# Ask each module for a list of public rooms given the last_joined_members
|
||||
# value from the since token and the probing limit.
|
||||
module_last_joined_members = None
|
||||
if last_joined_members is not None:
|
||||
module_last_joined_members = last_joined_members
|
||||
if last_module_index is not None and last_module_index < module_index:
|
||||
module_last_joined_members = module_last_joined_members - 1
|
||||
module_public_rooms = await fetch_public_rooms(
|
||||
network_tuple,
|
||||
search_filter,
|
||||
probing_limit,
|
||||
(batch_token.last_joined_members, batch_token.last_room_id)
|
||||
if batch_token
|
||||
else None,
|
||||
(
|
||||
module_last_joined_members,
|
||||
last_room_id if last_module_index == module_index else None,
|
||||
),
|
||||
forwards,
|
||||
)
|
||||
|
||||
print([r.room_id for r in module_public_rooms])
|
||||
|
||||
# We reverse for iteration to keep the order in the final list
|
||||
# since we preprend when inserting
|
||||
module_public_rooms.reverse()
|
||||
|
||||
# Insert the module's reported public rooms into the list
|
||||
for new_room in module_public_rooms:
|
||||
inserted = False
|
||||
for i in range(len(results)):
|
||||
r = results[i]
|
||||
if (
|
||||
forwards and new_room.num_joined_members >= r.num_joined_members
|
||||
) or (
|
||||
not forwards
|
||||
and new_room.num_joined_members <= r.num_joined_members
|
||||
):
|
||||
results.insert(i, new_room)
|
||||
inserted = True
|
||||
break
|
||||
if not inserted:
|
||||
if forwards:
|
||||
results.append(new_room)
|
||||
else:
|
||||
results.insert(0, new_room)
|
||||
room_ids_to_module_index[new_room.room_id] = module_index
|
||||
insert_into_result(new_room, module_index)
|
||||
|
||||
local_public_rooms = await self.store.get_largest_public_rooms(
|
||||
network_tuple,
|
||||
search_filter,
|
||||
probing_limit,
|
||||
bounds=(
|
||||
last_joined_members,
|
||||
last_room_id if last_module_index == None else None,
|
||||
),
|
||||
forwards=forwards,
|
||||
ignore_non_federatable=bool(from_remote_server_name),
|
||||
)
|
||||
|
||||
for r in local_public_rooms:
|
||||
insert_into_result(r, None)
|
||||
|
||||
# print("final")
|
||||
# print([r.room_id for r in results])
|
||||
|
||||
response: JsonDict = {}
|
||||
num_results = len(results)
|
||||
|
@ -231,13 +268,16 @@ class RoomListHandler:
|
|||
initial_entry = results[0]
|
||||
|
||||
if forwards:
|
||||
if batch_token is not None:
|
||||
if since_token is not None:
|
||||
# If there was a token given then we assume that there
|
||||
# must be previous results.
|
||||
response["prev_batch"] = RoomListNextBatch(
|
||||
last_joined_members=initial_entry.num_joined_members,
|
||||
last_room_id=initial_entry.room_id,
|
||||
direction_is_forward=False,
|
||||
last_module_index=room_ids_to_module_index.get(
|
||||
initial_entry.room_id
|
||||
),
|
||||
).to_token()
|
||||
|
||||
if more_to_come:
|
||||
|
@ -245,13 +285,19 @@ class RoomListHandler:
|
|||
last_joined_members=final_entry.num_joined_members,
|
||||
last_room_id=final_entry.room_id,
|
||||
direction_is_forward=True,
|
||||
last_module_index=room_ids_to_module_index.get(
|
||||
final_entry.room_id
|
||||
),
|
||||
).to_token()
|
||||
else:
|
||||
if batch_token is not None:
|
||||
if since_token is not None:
|
||||
response["next_batch"] = RoomListNextBatch(
|
||||
last_joined_members=final_entry.num_joined_members,
|
||||
last_room_id=final_entry.room_id,
|
||||
direction_is_forward=True,
|
||||
last_module_index=room_ids_to_module_index.get(
|
||||
final_entry.room_id
|
||||
),
|
||||
).to_token()
|
||||
|
||||
if more_to_come:
|
||||
|
@ -259,6 +305,9 @@ class RoomListHandler:
|
|||
last_joined_members=initial_entry.num_joined_members,
|
||||
last_room_id=initial_entry.room_id,
|
||||
direction_is_forward=False,
|
||||
last_module_index=room_ids_to_module_index.get(
|
||||
initial_entry.room_id
|
||||
),
|
||||
).to_token()
|
||||
|
||||
response["chunk"] = [attr.asdict(r, filter=filter_none) for r in results]
|
||||
|
@ -507,11 +556,13 @@ class RoomListNextBatch:
|
|||
last_joined_members: int # The count to get rooms after/before
|
||||
last_room_id: str # The room_id to get rooms after/before
|
||||
direction_is_forward: bool # True if this is a next_batch, false if prev_batch
|
||||
last_module_index: Optional[int] = None
|
||||
|
||||
KEY_DICT = {
|
||||
"last_joined_members": "m",
|
||||
"last_room_id": "r",
|
||||
"direction_is_forward": "d",
|
||||
"last_module_index": "i",
|
||||
}
|
||||
|
||||
REVERSE_KEY_DICT = {v: k for k, v in KEY_DICT.items()}
|
||||
|
@ -524,6 +575,7 @@ class RoomListNextBatch:
|
|||
)
|
||||
|
||||
def to_token(self) -> str:
|
||||
# print(self)
|
||||
return encode_base64(
|
||||
msgpack.dumps(
|
||||
{self.KEY_DICT[key]: val for key, val in attr.asdict(self).items()}
|
||||
|
|
|
@ -372,7 +372,7 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
|
|||
network_tuple: Optional[ThirdPartyInstanceID],
|
||||
search_filter: Optional[dict],
|
||||
limit: Optional[int],
|
||||
bounds: Optional[Tuple[int, str]],
|
||||
bounds: Tuple[Optional[int], Optional[str]],
|
||||
forwards: bool,
|
||||
ignore_non_federatable: bool = False,
|
||||
) -> List[PublicRoom]:
|
||||
|
@ -420,26 +420,20 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
|
|||
# Work out the bounds if we're given them, these bounds look slightly
|
||||
# odd, but are designed to help query planner use indices by pulling
|
||||
# out a common bound.
|
||||
if bounds:
|
||||
last_joined_members, last_room_id = bounds
|
||||
if forwards:
|
||||
where_clauses.append(
|
||||
"""
|
||||
joined_members <= ? AND (
|
||||
joined_members < ? OR room_id < ?
|
||||
)
|
||||
"""
|
||||
)
|
||||
else:
|
||||
where_clauses.append(
|
||||
"""
|
||||
joined_members >= ? AND (
|
||||
joined_members > ? OR room_id > ?
|
||||
)
|
||||
"""
|
||||
)
|
||||
last_joined_members, last_room_id = bounds
|
||||
if last_joined_members is not None:
|
||||
comp = "<" if forwards else ">"
|
||||
|
||||
query_args += [last_joined_members, last_joined_members, last_room_id]
|
||||
clause = f"joined_members {comp}= ? AND (joined_members {comp} ?"
|
||||
query_args += [last_joined_members, last_joined_members]
|
||||
|
||||
if last_room_id is None:
|
||||
clause += ")"
|
||||
else:
|
||||
clause += f"OR room_id {comp} ?)"
|
||||
query_args.append(last_room_id)
|
||||
|
||||
where_clauses.append(clause)
|
||||
|
||||
if ignore_non_federatable:
|
||||
where_clauses.append("is_federatable")
|
||||
|
|
|
@ -49,106 +49,218 @@ class FetchPublicRoomsTestCase(HomeserverTestCase):
|
|||
network_tuple: Optional[ThirdPartyInstanceID],
|
||||
search_filter: Optional[dict],
|
||||
limit: Optional[int],
|
||||
bounds: Optional[Tuple[int, str]],
|
||||
bounds: Tuple[Optional[int], Optional[str]],
|
||||
forwards: bool,
|
||||
) -> List[PublicRoom]:
|
||||
room1 = PublicRoom(
|
||||
room_id="!test1:test",
|
||||
room_id="!one_members:module1",
|
||||
num_joined_members=1,
|
||||
world_readable=True,
|
||||
guest_can_join=False,
|
||||
)
|
||||
room3 = PublicRoom(
|
||||
room_id="!test3:test",
|
||||
room_id="!three_members:module1",
|
||||
num_joined_members=3,
|
||||
world_readable=True,
|
||||
guest_can_join=False,
|
||||
)
|
||||
room3_2 = PublicRoom(
|
||||
room_id="!test3_2:test",
|
||||
room_id="!three_members_2:module1",
|
||||
num_joined_members=3,
|
||||
world_readable=True,
|
||||
guest_can_join=False,
|
||||
)
|
||||
|
||||
if forwards:
|
||||
if limit == 2:
|
||||
if bounds is None:
|
||||
return [room3_2, room3]
|
||||
(last_joined_members, last_room_id) = bounds
|
||||
if last_joined_members == 3:
|
||||
if last_room_id == room3_2.room_id:
|
||||
return [room3, room1]
|
||||
if last_room_id == room3.room_id:
|
||||
return [room1]
|
||||
elif last_joined_members < 3:
|
||||
return [room1]
|
||||
return [room3_2, room3, room1]
|
||||
else:
|
||||
if limit == 2 and bounds is not None:
|
||||
(last_joined_members, last_room_id) = bounds
|
||||
if last_joined_members == 3:
|
||||
if last_room_id == room3.room_id:
|
||||
return [room3_2]
|
||||
return [room1, room3, room3_2]
|
||||
(last_joined_members, last_room_id) = bounds
|
||||
|
||||
print(f"cb {forwards} {bounds}")
|
||||
|
||||
result = [room1, room3, room3_2]
|
||||
|
||||
if last_joined_members is not None:
|
||||
if forwards:
|
||||
result = list(
|
||||
filter(
|
||||
lambda r: r.num_joined_members <= last_joined_members,
|
||||
result,
|
||||
)
|
||||
)
|
||||
else:
|
||||
result = list(
|
||||
filter(
|
||||
lambda r: r.num_joined_members >= last_joined_members,
|
||||
result,
|
||||
)
|
||||
)
|
||||
|
||||
print([r.room_id for r in result])
|
||||
|
||||
if last_room_id is not None:
|
||||
new_res = []
|
||||
for r in result:
|
||||
if r.room_id == last_room_id:
|
||||
break
|
||||
new_res.append(r)
|
||||
result = new_res
|
||||
|
||||
if forwards:
|
||||
result.reverse()
|
||||
|
||||
if limit is not None:
|
||||
result = result[:limit]
|
||||
|
||||
return result
|
||||
|
||||
# if forwards:
|
||||
# if limit == 2:
|
||||
# if last_joined_members is None:
|
||||
# return [room3_2, room3]
|
||||
# elif last_joined_members == 3:
|
||||
# if last_room_id == room3_2.room_id:
|
||||
# return [room3, room1]
|
||||
# if last_room_id == room3.room_id:
|
||||
# return [room1]
|
||||
# elif last_joined_members < 3:
|
||||
# return [room1]
|
||||
# return [room3_2, room3, room1]
|
||||
# else:
|
||||
# if (
|
||||
# limit == 2
|
||||
# and last_joined_members == 3
|
||||
# and last_room_id == room3.room_id
|
||||
# ):
|
||||
# return [room3_2]
|
||||
# return [room1, room3, room3_2]
|
||||
|
||||
async def cb2(
|
||||
network_tuple: Optional[ThirdPartyInstanceID],
|
||||
search_filter: Optional[dict],
|
||||
limit: Optional[int],
|
||||
bounds: Tuple[Optional[int], Optional[str]],
|
||||
forwards: bool,
|
||||
) -> List[PublicRoom]:
|
||||
room3 = PublicRoom(
|
||||
room_id="!three_members:module2",
|
||||
num_joined_members=3,
|
||||
world_readable=True,
|
||||
guest_can_join=False,
|
||||
)
|
||||
|
||||
result = [room3]
|
||||
|
||||
(last_joined_members, last_room_id) = bounds
|
||||
|
||||
print(f"cb2 {forwards} {bounds}")
|
||||
|
||||
if last_joined_members is not None:
|
||||
if forwards:
|
||||
result = list(
|
||||
filter(
|
||||
lambda r: r.num_joined_members <= last_joined_members,
|
||||
result,
|
||||
)
|
||||
)
|
||||
else:
|
||||
result = list(
|
||||
filter(
|
||||
lambda r: r.num_joined_members >= last_joined_members,
|
||||
result,
|
||||
)
|
||||
)
|
||||
|
||||
print([r.room_id for r in result])
|
||||
|
||||
if last_room_id is not None:
|
||||
new_res = []
|
||||
for r in result:
|
||||
if r.room_id == last_room_id:
|
||||
break
|
||||
new_res.append(r)
|
||||
result = new_res
|
||||
|
||||
if forwards:
|
||||
result.reverse()
|
||||
|
||||
if limit is not None:
|
||||
result = result[:limit]
|
||||
|
||||
return result
|
||||
|
||||
self._module_api.register_public_rooms_callbacks(fetch_public_rooms=cb2)
|
||||
self._module_api.register_public_rooms_callbacks(fetch_public_rooms=cb)
|
||||
|
||||
user = self.register_user("alice", "pass")
|
||||
token = self.login(user, "pass")
|
||||
|
||||
# Create a room
|
||||
user2 = self.register_user("alice2", "pass")
|
||||
token2 = self.login(user2, "pass")
|
||||
|
||||
user3 = self.register_user("alice3", "pass")
|
||||
token3 = self.login(user3, "pass")
|
||||
|
||||
# Create a room with 2 people
|
||||
room_id = self.helper.create_room_as(
|
||||
user,
|
||||
is_public=True,
|
||||
extra_content={"visibility": "public"},
|
||||
tok=token,
|
||||
)
|
||||
|
||||
user2 = self.register_user("alice2", "pass")
|
||||
token2 = self.login(user2, "pass")
|
||||
self.helper.join(room_id, user2, tok=token2)
|
||||
|
||||
# Create a room with 3 people
|
||||
room_id = self.helper.create_room_as(
|
||||
user,
|
||||
is_public=True,
|
||||
extra_content={"visibility": "public"},
|
||||
tok=token,
|
||||
)
|
||||
self.helper.join(room_id, user2, tok=token2)
|
||||
self.helper.join(room_id, user3, tok=token3)
|
||||
|
||||
def test_no_limit(self) -> None:
|
||||
channel = self.make_request("GET", self.url)
|
||||
self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
|
||||
chunk = channel.json_body["chunk"]
|
||||
|
||||
self.assertEquals(len(channel.json_body["chunk"]), 4)
|
||||
self.assertEquals(channel.json_body["chunk"][0]["num_joined_members"], 3)
|
||||
self.assertEquals(channel.json_body["chunk"][1]["num_joined_members"], 3)
|
||||
self.assertEquals(channel.json_body["chunk"][2]["num_joined_members"], 2)
|
||||
self.assertEquals(channel.json_body["chunk"][3]["num_joined_members"], 1)
|
||||
self.assertEquals(len(chunk), 6)
|
||||
for i in range(4):
|
||||
self.assertEquals(chunk[i]["num_joined_members"], 3)
|
||||
self.assertEquals(chunk[4]["num_joined_members"], 2)
|
||||
self.assertEquals(chunk[5]["num_joined_members"], 1)
|
||||
|
||||
def test_pagination(self) -> None:
|
||||
channel = self.make_request("GET", self.url + "?limit=1")
|
||||
self.assertEquals(channel.json_body["chunk"][0]["num_joined_members"], 3)
|
||||
returned_room3_id = channel.json_body["chunk"][0]["room_id"]
|
||||
returned_three_members_rooms = set()
|
||||
|
||||
next_batch = None
|
||||
for i in range(4):
|
||||
since_query_str = f"&since={next_batch}" if next_batch else ""
|
||||
channel = self.make_request("GET", f"{self.url}?limit=1{since_query_str}")
|
||||
chunk = channel.json_body["chunk"]
|
||||
self.assertEquals(chunk[0]["num_joined_members"], 3)
|
||||
self.assertTrue(chunk[0]["room_id"] not in returned_three_members_rooms)
|
||||
returned_three_members_rooms.add(chunk[0]["room_id"])
|
||||
next_batch = channel.json_body["next_batch"]
|
||||
|
||||
channel = self.make_request("GET", f"{self.url}?limit=1&since={next_batch}")
|
||||
chunk = channel.json_body["chunk"]
|
||||
self.assertEquals(chunk[0]["num_joined_members"], 2)
|
||||
next_batch = channel.json_body["next_batch"]
|
||||
|
||||
channel = self.make_request("GET", f"{self.url}?limit=1&since={next_batch}")
|
||||
self.assertEquals(channel.json_body["chunk"][0]["num_joined_members"], 3)
|
||||
# We should get the other room with 3 users here
|
||||
self.assertNotEquals(
|
||||
returned_room3_id, channel.json_body["chunk"][0]["room_id"]
|
||||
)
|
||||
chunk = channel.json_body["chunk"]
|
||||
self.assertEquals(chunk[0]["num_joined_members"], 1)
|
||||
prev_batch = channel.json_body["prev_batch"]
|
||||
|
||||
channel = self.make_request("GET", f"{self.url}?limit=1&since={prev_batch}")
|
||||
self.assertEquals(channel.json_body["chunk"][0]["num_joined_members"], 3)
|
||||
self.assertEquals(returned_room3_id, channel.json_body["chunk"][0]["room_id"])
|
||||
next_batch = channel.json_body["next_batch"]
|
||||
# channel = self.make_request("GET", f"{self.url}?limit=1&since={prev_batch}")
|
||||
# chunk = channel.json_body["chunk"]
|
||||
# print(chunk)
|
||||
# self.assertEquals(chunk[0]["num_joined_members"], 2)
|
||||
# prev_batch = channel.json_body["prev_batch"]
|
||||
|
||||
# We went backwards once, so we should get same result as step 2
|
||||
channel = self.make_request("GET", f"{self.url}?limit=1&since={next_batch}")
|
||||
self.assertEquals(channel.json_body["chunk"][0]["num_joined_members"], 3)
|
||||
self.assertNotEquals(
|
||||
returned_room3_id, channel.json_body["chunk"][0]["room_id"]
|
||||
)
|
||||
next_batch = channel.json_body["next_batch"]
|
||||
|
||||
channel = self.make_request("GET", f"{self.url}?limit=1&since={next_batch}")
|
||||
self.assertEquals(channel.json_body["chunk"][0]["num_joined_members"], 2)
|
||||
next_batch = channel.json_body["next_batch"]
|
||||
|
||||
channel = self.make_request("GET", f"{self.url}?limit=1&since={next_batch}")
|
||||
self.assertEquals(channel.json_body["chunk"][0]["num_joined_members"], 1)
|
||||
# returned_three_members_rooms = set()
|
||||
# for i in range(4):
|
||||
# channel = self.make_request("GET", f"{self.url}?limit=1&since={prev_batch}")
|
||||
# chunk = channel.json_body["chunk"]
|
||||
# self.assertEquals(chunk[0]["num_joined_members"], 3)
|
||||
# self.assertTrue(chunk[0]["room_id"] not in returned_three_members_rooms)
|
||||
# returned_three_members_rooms.add(chunk[0]["room_id"])
|
||||
# prev_batch = channel.json_body["prev_batch"]
|
||||
|
|
Loading…
Reference in New Issue