Directly lookup local membership instead of getting all members in a room first (`get_users_in_room` mis-use) (#13608)
See https://github.com/matrix-org/synapse/pull/13575#discussion_r953023755pull/13629/head
							parent
							
								
									b93bd95e8a
								
							
						
					
					
						commit
						d58615c82c
					
				| 
						 | 
				
			
			@ -0,0 +1 @@
 | 
			
		|||
Refactor `get_users_in_room(room_id)` mis-use to lookup single local user with dedicated `check_local_user_in_room(...)` function.
 | 
			
		||||
| 
						 | 
				
			
			@ -151,7 +151,7 @@ class EventHandler:
 | 
			
		|||
        """Retrieve a single specified event.
 | 
			
		||||
 | 
			
		||||
        Args:
 | 
			
		||||
            user: The user requesting the event
 | 
			
		||||
            user: The local user requesting the event
 | 
			
		||||
            room_id: The expected room id. We'll return None if the
 | 
			
		||||
                event's room does not match.
 | 
			
		||||
            event_id: The event ID to obtain.
 | 
			
		||||
| 
						 | 
				
			
			@ -173,8 +173,11 @@ class EventHandler:
 | 
			
		|||
        if not event:
 | 
			
		||||
            return None
 | 
			
		||||
 | 
			
		||||
        users = await self.store.get_users_in_room(event.room_id)
 | 
			
		||||
        is_peeking = user.to_string() not in users
 | 
			
		||||
        is_user_in_room = await self.store.check_local_user_in_room(
 | 
			
		||||
            user_id=user.to_string(), room_id=event.room_id
 | 
			
		||||
        )
 | 
			
		||||
        # The user is peeking if they aren't in the room already
 | 
			
		||||
        is_peeking = not is_user_in_room
 | 
			
		||||
 | 
			
		||||
        filtered = await filter_events_for_client(
 | 
			
		||||
            self._storage_controllers, user.to_string(), [event], is_peeking=is_peeking
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -761,8 +761,10 @@ class EventCreationHandler:
 | 
			
		|||
    async def _is_server_notices_room(self, room_id: str) -> bool:
 | 
			
		||||
        if self.config.servernotices.server_notices_mxid is None:
 | 
			
		||||
            return False
 | 
			
		||||
        user_ids = await self.store.get_users_in_room(room_id)
 | 
			
		||||
        return self.config.servernotices.server_notices_mxid in user_ids
 | 
			
		||||
        is_server_notices_room = await self.store.check_local_user_in_room(
 | 
			
		||||
            user_id=self.config.servernotices.server_notices_mxid, room_id=room_id
 | 
			
		||||
        )
 | 
			
		||||
        return is_server_notices_room
 | 
			
		||||
 | 
			
		||||
    async def assert_accepted_privacy_policy(self, requester: Requester) -> None:
 | 
			
		||||
        """Check if a user has accepted the privacy policy
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -1284,8 +1284,11 @@ class RoomContextHandler:
 | 
			
		|||
        before_limit = math.floor(limit / 2.0)
 | 
			
		||||
        after_limit = limit - before_limit
 | 
			
		||||
 | 
			
		||||
        users = await self.store.get_users_in_room(room_id)
 | 
			
		||||
        is_peeking = user.to_string() not in users
 | 
			
		||||
        is_user_in_room = await self.store.check_local_user_in_room(
 | 
			
		||||
            user_id=user.to_string(), room_id=room_id
 | 
			
		||||
        )
 | 
			
		||||
        # The user is peeking if they aren't in the room already
 | 
			
		||||
        is_peeking = not is_user_in_room
 | 
			
		||||
 | 
			
		||||
        async def filter_evts(events: List[EventBase]) -> List[EventBase]:
 | 
			
		||||
            if use_admin_priviledge:
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -1620,8 +1620,10 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
 | 
			
		|||
    async def _is_server_notice_room(self, room_id: str) -> bool:
 | 
			
		||||
        if self._server_notices_mxid is None:
 | 
			
		||||
            return False
 | 
			
		||||
        user_ids = await self.store.get_users_in_room(room_id)
 | 
			
		||||
        return self._server_notices_mxid in user_ids
 | 
			
		||||
        is_server_notices_room = await self.store.check_local_user_in_room(
 | 
			
		||||
            user_id=self._server_notices_mxid, room_id=room_id
 | 
			
		||||
        )
 | 
			
		||||
        return is_server_notices_room
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class RoomMemberMasterHandler(RoomMemberHandler):
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -102,6 +102,10 @@ class ServerNoticesManager:
 | 
			
		|||
        Returns:
 | 
			
		||||
            The room's ID, or None if no room could be found.
 | 
			
		||||
        """
 | 
			
		||||
        # If there is no server notices MXID, then there is no server notices room
 | 
			
		||||
        if self.server_notices_mxid is None:
 | 
			
		||||
            return None
 | 
			
		||||
 | 
			
		||||
        rooms = await self._store.get_rooms_for_local_user_where_membership_is(
 | 
			
		||||
            user_id, [Membership.INVITE, Membership.JOIN]
 | 
			
		||||
        )
 | 
			
		||||
| 
						 | 
				
			
			@ -111,8 +115,10 @@ class ServerNoticesManager:
 | 
			
		|||
            # be joined. This is kinda deliberate, in that if somebody somehow
 | 
			
		||||
            # manages to invite the system user to a room, that doesn't make it
 | 
			
		||||
            # the server notices room.
 | 
			
		||||
            user_ids = await self._store.get_users_in_room(room.room_id)
 | 
			
		||||
            if len(user_ids) <= 2 and self.server_notices_mxid in user_ids:
 | 
			
		||||
            is_server_notices_room = await self._store.check_local_user_in_room(
 | 
			
		||||
                user_id=self.server_notices_mxid, room_id=room.room_id
 | 
			
		||||
            )
 | 
			
		||||
            if is_server_notices_room:
 | 
			
		||||
                # we found a room which our user shares with the system notice
 | 
			
		||||
                # user
 | 
			
		||||
                return room.room_id
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -534,6 +534,32 @@ class RoomMemberWorkerStore(EventsWorkerStore):
 | 
			
		|||
            desc="get_local_users_in_room",
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    async def check_local_user_in_room(self, user_id: str, room_id: str) -> bool:
 | 
			
		||||
        """
 | 
			
		||||
        Check whether a given local user is currently joined to the given room.
 | 
			
		||||
 | 
			
		||||
        Returns:
 | 
			
		||||
            A boolean indicating whether the user is currently joined to the room
 | 
			
		||||
 | 
			
		||||
        Raises:
 | 
			
		||||
            Exeption when called with a non-local user to this homeserver
 | 
			
		||||
        """
 | 
			
		||||
        if not self.hs.is_mine_id(user_id):
 | 
			
		||||
            raise Exception(
 | 
			
		||||
                "Cannot call 'check_local_user_in_room' on "
 | 
			
		||||
                "non-local user %s" % (user_id,),
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
        (
 | 
			
		||||
            membership,
 | 
			
		||||
            member_event_id,
 | 
			
		||||
        ) = await self.get_local_current_membership_for_user_in_room(
 | 
			
		||||
            user_id=user_id,
 | 
			
		||||
            room_id=room_id,
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        return membership == Membership.JOIN
 | 
			
		||||
 | 
			
		||||
    async def get_local_current_membership_for_user_in_room(
 | 
			
		||||
        self, user_id: str, room_id: str
 | 
			
		||||
    ) -> Tuple[Optional[str], Optional[str]]:
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -999,7 +999,7 @@ class BundledAggregationsTestCase(BaseRelationsTestCase):
 | 
			
		|||
                bundled_aggregations,
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
        self._test_bundled_aggregations(RelationTypes.ANNOTATION, assert_annotations, 6)
 | 
			
		||||
        self._test_bundled_aggregations(RelationTypes.ANNOTATION, assert_annotations, 7)
 | 
			
		||||
 | 
			
		||||
    def test_annotation_to_annotation(self) -> None:
 | 
			
		||||
        """Any relation to an annotation should be ignored."""
 | 
			
		||||
| 
						 | 
				
			
			@ -1035,7 +1035,7 @@ class BundledAggregationsTestCase(BaseRelationsTestCase):
 | 
			
		|||
                bundled_aggregations,
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
        self._test_bundled_aggregations(RelationTypes.REFERENCE, assert_annotations, 6)
 | 
			
		||||
        self._test_bundled_aggregations(RelationTypes.REFERENCE, assert_annotations, 7)
 | 
			
		||||
 | 
			
		||||
    def test_thread(self) -> None:
 | 
			
		||||
        """
 | 
			
		||||
| 
						 | 
				
			
			@ -1080,21 +1080,21 @@ class BundledAggregationsTestCase(BaseRelationsTestCase):
 | 
			
		|||
 | 
			
		||||
        # The "user" sent the root event and is making queries for the bundled
 | 
			
		||||
        # aggregations: they have participated.
 | 
			
		||||
        self._test_bundled_aggregations(RelationTypes.THREAD, _gen_assert(True), 8)
 | 
			
		||||
        self._test_bundled_aggregations(RelationTypes.THREAD, _gen_assert(True), 9)
 | 
			
		||||
        # The "user2" sent replies in the thread and is making queries for the
 | 
			
		||||
        # bundled aggregations: they have participated.
 | 
			
		||||
        #
 | 
			
		||||
        # Note that this re-uses some cached values, so the total number of
 | 
			
		||||
        # queries is much smaller.
 | 
			
		||||
        self._test_bundled_aggregations(
 | 
			
		||||
            RelationTypes.THREAD, _gen_assert(True), 2, access_token=self.user2_token
 | 
			
		||||
            RelationTypes.THREAD, _gen_assert(True), 3, access_token=self.user2_token
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        # A user with no interactions with the thread: they have not participated.
 | 
			
		||||
        user3_id, user3_token = self._create_user("charlie")
 | 
			
		||||
        self.helper.join(self.room, user=user3_id, tok=user3_token)
 | 
			
		||||
        self._test_bundled_aggregations(
 | 
			
		||||
            RelationTypes.THREAD, _gen_assert(False), 2, access_token=user3_token
 | 
			
		||||
            RelationTypes.THREAD, _gen_assert(False), 3, access_token=user3_token
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    def test_thread_with_bundled_aggregations_for_latest(self) -> None:
 | 
			
		||||
| 
						 | 
				
			
			@ -1142,7 +1142,7 @@ class BundledAggregationsTestCase(BaseRelationsTestCase):
 | 
			
		|||
                bundled_aggregations["latest_event"].get("unsigned"),
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
        self._test_bundled_aggregations(RelationTypes.THREAD, assert_thread, 8)
 | 
			
		||||
        self._test_bundled_aggregations(RelationTypes.THREAD, assert_thread, 9)
 | 
			
		||||
 | 
			
		||||
    def test_nested_thread(self) -> None:
 | 
			
		||||
        """
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in New Issue