diff --git a/changelog.d/12836.misc b/changelog.d/12836.misc new file mode 100644 index 0000000000..85909c6a2d --- /dev/null +++ b/changelog.d/12836.misc @@ -0,0 +1 @@ +Remove Mutual Rooms ([MSC2666](https://github.com/matrix-org/matrix-spec-proposals/pull/2666)) endpoint dependency on the User Directory. \ No newline at end of file diff --git a/synapse/rest/client/mutual_rooms.py b/synapse/rest/client/mutual_rooms.py index 27bfaf0b29..38ef4e459f 100644 --- a/synapse/rest/client/mutual_rooms.py +++ b/synapse/rest/client/mutual_rooms.py @@ -42,21 +42,10 @@ class UserMutualRoomsServlet(RestServlet): super().__init__() self.auth = hs.get_auth() self.store = hs.get_datastores().main - self.user_directory_search_enabled = ( - hs.config.userdirectory.user_directory_search_enabled - ) async def on_GET( self, request: SynapseRequest, user_id: str ) -> Tuple[int, JsonDict]: - - if not self.user_directory_search_enabled: - raise SynapseError( - code=400, - msg="User directory searching is disabled. Cannot determine shared rooms.", - errcode=Codes.UNKNOWN, - ) - UserID.from_string(user_id) requester = await self.auth.get_user_by_req(request) @@ -67,8 +56,8 @@ class UserMutualRoomsServlet(RestServlet): errcode=Codes.FORBIDDEN, ) - rooms = await self.store.get_mutual_rooms_for_users( - requester.user.to_string(), user_id + rooms = await self.store.get_mutual_rooms_between_users( + frozenset((requester.user.to_string(), user_id)) ) return 200, {"joined": list(rooms)} diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py index cc528fcf2d..e222b7bd1f 100644 --- a/synapse/storage/databases/main/roommember.py +++ b/synapse/storage/databases/main/roommember.py @@ -670,6 +670,30 @@ class RoomMemberWorkerStore(EventsWorkerStore): return user_who_share_room + @cached(cache_context=True, iterable=True) + async def get_mutual_rooms_between_users( + self, user_ids: FrozenSet[str], cache_context: _CacheContext + ) -> FrozenSet[str]: + """ + Returns the set of rooms that all users in `user_ids` share. + + Args: + user_ids: A frozen set of all users to investigate and return + overlapping joined rooms for. + cache_context + """ + shared_room_ids: Optional[FrozenSet[str]] = None + for user_id in user_ids: + room_ids = await self.get_rooms_for_user( + user_id, on_invalidate=cache_context.invalidate + ) + if shared_room_ids is not None: + shared_room_ids &= room_ids + else: + shared_room_ids = room_ids + + return shared_room_ids or frozenset() + async def get_joined_users_from_context( self, event: EventBase, context: EventContext ) -> Dict[str, ProfileInfo]: diff --git a/synapse/storage/databases/main/user_directory.py b/synapse/storage/databases/main/user_directory.py index 028db69af3..2282242e9d 100644 --- a/synapse/storage/databases/main/user_directory.py +++ b/synapse/storage/databases/main/user_directory.py @@ -729,49 +729,6 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore): users.update(rows) return list(users) - async def get_mutual_rooms_for_users( - self, user_id: str, other_user_id: str - ) -> Set[str]: - """ - Returns the rooms that a local user shares with another local or remote user. - - Args: - user_id: The MXID of a local user - other_user_id: The MXID of the other user - - Returns: - A set of room ID's that the users share. - """ - - def _get_mutual_rooms_for_users_txn( - txn: LoggingTransaction, - ) -> List[Dict[str, str]]: - txn.execute( - """ - SELECT p1.room_id - FROM users_in_public_rooms as p1 - INNER JOIN users_in_public_rooms as p2 - ON p1.room_id = p2.room_id - AND p1.user_id = ? - AND p2.user_id = ? - UNION - SELECT room_id - FROM users_who_share_private_rooms - WHERE - user_id = ? - AND other_user_id = ? - """, - (user_id, other_user_id, user_id, other_user_id), - ) - rows = self.db_pool.cursor_to_dict(txn) - return rows - - rows = await self.db_pool.runInteraction( - "get_mutual_rooms_for_users", _get_mutual_rooms_for_users_txn - ) - - return {row["room_id"] for row in rows} - async def get_user_directory_stream_pos(self) -> Optional[int]: """ Get the stream ID of the user directory stream. diff --git a/tests/rest/client/test_mutual_rooms.py b/tests/rest/client/test_mutual_rooms.py index 7b7d283bb6..a4327f7ace 100644 --- a/tests/rest/client/test_mutual_rooms.py +++ b/tests/rest/client/test_mutual_rooms.py @@ -36,12 +36,10 @@ class UserMutualRoomsTest(unittest.HomeserverTestCase): def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: config = self.default_config() - config["update_user_directory"] = True return self.setup_test_homeserver(config=config) def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.store = hs.get_datastores().main - self.handler = hs.get_user_directory_handler() def _get_mutual_rooms(self, token: str, other_user: str) -> FakeChannel: return self.make_request(