More tests, less bugs

anoa/public_rooms_module_api
Mathieu Velten 2023-05-30 17:41:43 +02:00
parent 3194933c1b
commit dcc49cd1ae
3 changed files with 113 additions and 119 deletions

View File

@ -13,7 +13,7 @@
# limitations under the License. # limitations under the License.
import logging import logging
from typing import TYPE_CHECKING, Any, Optional, Tuple from typing import TYPE_CHECKING, Any, List, Optional, Tuple
import attr import attr
import msgpack import msgpack
@ -170,12 +170,13 @@ class RoomListHandler:
# we request one more than wanted to see if there are more pages to come # we request one more than wanted to see if there are more pages to come
probing_limit = limit + 1 if limit is not None else None probing_limit = limit + 1 if limit is not None else None
results = [] results: List[PublicRoom] = []
print(f"last_module_index {last_module_index}") # print(f"{forwards} {last_joined_members} {last_room_id} {last_module_index}")
print(f"last_room_id {last_room_id}")
def insert_into_result(new_room: PublicRoom, module_index: Optional[int]): def insert_into_result(
new_room: PublicRoom, module_index: Optional[int]
) -> None:
# print(f"insert {new_room.room_id} {module_index}") # print(f"insert {new_room.room_id} {module_index}")
if new_room.num_joined_members == last_joined_members: if new_room.num_joined_members == last_joined_members:
if last_module_index is not None and last_room_id is not None: if last_module_index is not None and last_room_id is not None:
@ -221,8 +222,6 @@ class RoomListHandler:
forwards, forwards,
) )
print([r.room_id for r in module_public_rooms])
# We reverse for iteration to keep the order in the final list # We reverse for iteration to keep the order in the final list
# since we preprend when inserting # since we preprend when inserting
module_public_rooms.reverse() module_public_rooms.reverse()
@ -238,7 +237,7 @@ class RoomListHandler:
probing_limit, probing_limit,
bounds=( bounds=(
last_joined_members, last_joined_members,
last_room_id if last_module_index == None else None, last_room_id if last_module_index is None else None,
), ),
forwards=forwards, forwards=forwards,
ignore_non_federatable=bool(from_remote_server_name), ignore_non_federatable=bool(from_remote_server_name),
@ -247,22 +246,20 @@ class RoomListHandler:
for r in local_public_rooms: for r in local_public_rooms:
insert_into_result(r, None) insert_into_result(r, None)
# print("final")
# print([r.room_id for r in results])
response: JsonDict = {} response: JsonDict = {}
num_results = len(results) num_results = len(results)
if limit is not None and probing_limit is not None: if limit is not None and probing_limit is not None:
more_to_come = num_results >= probing_limit more_to_come = num_results >= probing_limit
# Depending on direction we trim either the front or back. results = results[:limit]
if forwards:
results = results[:limit]
else:
results = results[-limit:]
else: else:
more_to_come = False more_to_come = False
if not forwards:
results.reverse()
# print("final ", [(r.room_id, r.num_joined_members) for r in results])
if num_results > 0: if num_results > 0:
final_entry = results[-1] final_entry = results[-1]
initial_entry = results[0] initial_entry = results[0]

View File

@ -26,7 +26,7 @@ FETCH_PUBLIC_ROOMS_CALLBACK = Callable[
Optional[ThirdPartyInstanceID], # network_tuple Optional[ThirdPartyInstanceID], # network_tuple
Optional[dict], # search_filter Optional[dict], # search_filter
Optional[int], # limit Optional[int], # limit
Optional[Tuple[int, str]], # bounds Tuple[Optional[int], Optional[str]], # bounds
bool, # forwards bool, # forwards
], ],
Awaitable[List[PublicRoom]], Awaitable[List[PublicRoom]],

View File

@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from http import HTTPStatus
from typing import List, Optional, Tuple from typing import List, Optional, Tuple
from twisted.test.proto_helpers import MemoryReactor from twisted.test.proto_helpers import MemoryReactor
@ -45,7 +44,7 @@ class FetchPublicRoomsTestCase(HomeserverTestCase):
self._store = homeserver.get_datastores().main self._store = homeserver.get_datastores().main
self._module_api = homeserver.get_module_api() self._module_api = homeserver.get_module_api()
async def cb( async def module1_cb(
network_tuple: Optional[ThirdPartyInstanceID], network_tuple: Optional[ThirdPartyInstanceID],
search_filter: Optional[dict], search_filter: Optional[dict],
limit: Optional[int], limit: Optional[int],
@ -73,66 +72,48 @@ class FetchPublicRoomsTestCase(HomeserverTestCase):
(last_joined_members, last_room_id) = bounds (last_joined_members, last_room_id) = bounds
print(f"cb {forwards} {bounds}") if forwards:
result = [room3_2, room3, room1]
result = [room1, room3, room3_2] else:
result = [room1, room3, room3_2]
if last_joined_members is not None: if last_joined_members is not None:
if forwards: if last_joined_members == 1:
result = list( if forwards:
filter( if last_room_id == room1.room_id:
lambda r: r.num_joined_members <= last_joined_members, result = []
result, else:
) result = [room1]
) else:
else: if last_room_id == room1.room_id:
result = list( result = [room3, room3_2]
filter( else:
lambda r: r.num_joined_members >= last_joined_members, result = [room1, room3, room3_2]
result, elif last_joined_members == 2:
) if forwards:
) result = [room1]
else:
print([r.room_id for r in result]) result = [room3, room3_2]
elif last_joined_members == 3:
if last_room_id is not None: if forwards:
new_res = [] if last_room_id == room3.room_id:
for r in result: result = [room1]
if r.room_id == last_room_id: elif last_room_id == room3_2.room_id:
break result = [room3, room1]
new_res.append(r) else:
result = new_res if last_room_id == room3.room_id:
result = [room3_2]
if forwards: elif last_room_id == room3_2.room_id:
result.reverse() result = []
else:
result = [room3, room3_2]
if limit is not None: if limit is not None:
result = result[:limit] result = result[:limit]
return result return result
# if forwards: async def module2_cb(
# if limit == 2:
# if last_joined_members is None:
# return [room3_2, room3]
# elif last_joined_members == 3:
# if last_room_id == room3_2.room_id:
# return [room3, room1]
# if last_room_id == room3.room_id:
# return [room1]
# elif last_joined_members < 3:
# return [room1]
# return [room3_2, room3, room1]
# else:
# if (
# limit == 2
# and last_joined_members == 3
# and last_room_id == room3.room_id
# ):
# return [room3_2]
# return [room1, room3, room3_2]
async def cb2(
network_tuple: Optional[ThirdPartyInstanceID], network_tuple: Optional[ThirdPartyInstanceID],
search_filter: Optional[dict], search_filter: Optional[dict],
limit: Optional[int], limit: Optional[int],
@ -146,48 +127,26 @@ class FetchPublicRoomsTestCase(HomeserverTestCase):
guest_can_join=False, guest_can_join=False,
) )
result = [room3]
(last_joined_members, last_room_id) = bounds (last_joined_members, last_room_id) = bounds
print(f"cb2 {forwards} {bounds}") result = [room3]
if last_joined_members is not None: if last_joined_members is not None:
if forwards: if forwards:
result = list( if last_joined_members < 3:
filter( result = []
lambda r: r.num_joined_members <= last_joined_members, elif last_joined_members == 3 and last_room_id == room3.room_id:
result, result = []
)
)
else: else:
result = list( if last_joined_members > 3:
filter( result = []
lambda r: r.num_joined_members >= last_joined_members, elif last_joined_members == 3 and last_room_id == room3.room_id:
result, result = []
)
)
print([r.room_id for r in result])
if last_room_id is not None:
new_res = []
for r in result:
if r.room_id == last_room_id:
break
new_res.append(r)
result = new_res
if forwards:
result.reverse()
if limit is not None:
result = result[:limit]
return result return result
self._module_api.register_public_rooms_callbacks(fetch_public_rooms=cb2) self._module_api.register_public_rooms_callbacks(fetch_public_rooms=module1_cb)
self._module_api.register_public_rooms_callbacks(fetch_public_rooms=cb) self._module_api.register_public_rooms_callbacks(fetch_public_rooms=module2_cb)
user = self.register_user("alice", "pass") user = self.register_user("alice", "pass")
token = self.login(user, "pass") token = self.login(user, "pass")
@ -227,11 +186,11 @@ class FetchPublicRoomsTestCase(HomeserverTestCase):
self.assertEquals(chunk[4]["num_joined_members"], 2) self.assertEquals(chunk[4]["num_joined_members"], 2)
self.assertEquals(chunk[5]["num_joined_members"], 1) self.assertEquals(chunk[5]["num_joined_members"], 1)
def test_pagination(self) -> None: def test_pagination_limit_1(self) -> None:
returned_three_members_rooms = set() returned_three_members_rooms = set()
next_batch = None next_batch = None
for i in range(4): for _i in range(4):
since_query_str = f"&since={next_batch}" if next_batch else "" since_query_str = f"&since={next_batch}" if next_batch else ""
channel = self.make_request("GET", f"{self.url}?limit=1{since_query_str}") channel = self.make_request("GET", f"{self.url}?limit=1{since_query_str}")
chunk = channel.json_body["chunk"] chunk = channel.json_body["chunk"]
@ -250,17 +209,55 @@ class FetchPublicRoomsTestCase(HomeserverTestCase):
self.assertEquals(chunk[0]["num_joined_members"], 1) self.assertEquals(chunk[0]["num_joined_members"], 1)
prev_batch = channel.json_body["prev_batch"] prev_batch = channel.json_body["prev_batch"]
# channel = self.make_request("GET", f"{self.url}?limit=1&since={prev_batch}") self.assertNotIn("next_batch", channel.json_body)
# chunk = channel.json_body["chunk"]
# print(chunk)
# self.assertEquals(chunk[0]["num_joined_members"], 2)
# prev_batch = channel.json_body["prev_batch"]
# returned_three_members_rooms = set() channel = self.make_request("GET", f"{self.url}?limit=1&since={prev_batch}")
# for i in range(4): chunk = channel.json_body["chunk"]
# channel = self.make_request("GET", f"{self.url}?limit=1&since={prev_batch}") self.assertEquals(chunk[0]["num_joined_members"], 2)
# chunk = channel.json_body["chunk"]
# self.assertEquals(chunk[0]["num_joined_members"], 3) returned_three_members_rooms = set()
# self.assertTrue(chunk[0]["room_id"] not in returned_three_members_rooms) for _i in range(4):
# returned_three_members_rooms.add(chunk[0]["room_id"]) prev_batch = channel.json_body["prev_batch"]
# prev_batch = channel.json_body["prev_batch"] channel = self.make_request("GET", f"{self.url}?limit=1&since={prev_batch}")
chunk = channel.json_body["chunk"]
self.assertEquals(chunk[0]["num_joined_members"], 3)
self.assertTrue(chunk[0]["room_id"] not in returned_three_members_rooms)
returned_three_members_rooms.add(chunk[0]["room_id"])
self.assertNotIn("prev_batch", channel.json_body)
def test_pagination_limit_2(self) -> None:
returned_three_members_rooms = set()
next_batch = None
for _i in range(2):
since_query_str = f"&since={next_batch}" if next_batch else ""
channel = self.make_request("GET", f"{self.url}?limit=2{since_query_str}")
chunk = channel.json_body["chunk"]
self.assertEquals(chunk[0]["num_joined_members"], 3)
self.assertTrue(chunk[0]["room_id"] not in returned_three_members_rooms)
returned_three_members_rooms.add(chunk[0]["room_id"])
self.assertTrue(chunk[1]["room_id"] not in returned_three_members_rooms)
returned_three_members_rooms.add(chunk[1]["room_id"])
next_batch = channel.json_body["next_batch"]
channel = self.make_request("GET", f"{self.url}?limit=2&since={next_batch}")
chunk = channel.json_body["chunk"]
self.assertEquals(chunk[0]["num_joined_members"], 2)
self.assertEquals(chunk[1]["num_joined_members"], 1)
self.assertNotIn("next_batch", channel.json_body)
returned_three_members_rooms = set()
for _i in range(2):
prev_batch = channel.json_body["prev_batch"]
channel = self.make_request("GET", f"{self.url}?limit=2&since={prev_batch}")
chunk = channel.json_body["chunk"]
self.assertEquals(chunk[0]["num_joined_members"], 3)
self.assertTrue(chunk[0]["room_id"] not in returned_three_members_rooms)
returned_three_members_rooms.add(chunk[0]["room_id"])
self.assertTrue(chunk[1]["room_id"] not in returned_three_members_rooms)
returned_three_members_rooms.add(chunk[1]["room_id"])
self.assertNotIn("prev_batch", channel.json_body)