Directly lookup local membership instead of getting all members in a room first (`get_users_in_room` mis-use) (#13608)

See https://github.com/matrix-org/synapse/pull/13575#discussion_r953023755
pull/13629/head
Eric Eastwood 2022-08-24 14:13:12 -05:00 committed by GitHub
parent b93bd95e8a
commit d58615c82c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 60 additions and 17 deletions

1
changelog.d/13608.misc Normal file
View File

@ -0,0 +1 @@
Refactor `get_users_in_room(room_id)` mis-use to lookup single local user with dedicated `check_local_user_in_room(...)` function.

View File

@ -151,7 +151,7 @@ class EventHandler:
"""Retrieve a single specified event. """Retrieve a single specified event.
Args: Args:
user: The user requesting the event user: The local user requesting the event
room_id: The expected room id. We'll return None if the room_id: The expected room id. We'll return None if the
event's room does not match. event's room does not match.
event_id: The event ID to obtain. event_id: The event ID to obtain.
@ -173,8 +173,11 @@ class EventHandler:
if not event: if not event:
return None return None
users = await self.store.get_users_in_room(event.room_id) is_user_in_room = await self.store.check_local_user_in_room(
is_peeking = user.to_string() not in users user_id=user.to_string(), room_id=event.room_id
)
# The user is peeking if they aren't in the room already
is_peeking = not is_user_in_room
filtered = await filter_events_for_client( filtered = await filter_events_for_client(
self._storage_controllers, user.to_string(), [event], is_peeking=is_peeking self._storage_controllers, user.to_string(), [event], is_peeking=is_peeking

View File

@ -761,8 +761,10 @@ class EventCreationHandler:
async def _is_server_notices_room(self, room_id: str) -> bool: async def _is_server_notices_room(self, room_id: str) -> bool:
if self.config.servernotices.server_notices_mxid is None: if self.config.servernotices.server_notices_mxid is None:
return False return False
user_ids = await self.store.get_users_in_room(room_id) is_server_notices_room = await self.store.check_local_user_in_room(
return self.config.servernotices.server_notices_mxid in user_ids user_id=self.config.servernotices.server_notices_mxid, room_id=room_id
)
return is_server_notices_room
async def assert_accepted_privacy_policy(self, requester: Requester) -> None: async def assert_accepted_privacy_policy(self, requester: Requester) -> None:
"""Check if a user has accepted the privacy policy """Check if a user has accepted the privacy policy

View File

@ -1284,8 +1284,11 @@ class RoomContextHandler:
before_limit = math.floor(limit / 2.0) before_limit = math.floor(limit / 2.0)
after_limit = limit - before_limit after_limit = limit - before_limit
users = await self.store.get_users_in_room(room_id) is_user_in_room = await self.store.check_local_user_in_room(
is_peeking = user.to_string() not in users user_id=user.to_string(), room_id=room_id
)
# The user is peeking if they aren't in the room already
is_peeking = not is_user_in_room
async def filter_evts(events: List[EventBase]) -> List[EventBase]: async def filter_evts(events: List[EventBase]) -> List[EventBase]:
if use_admin_priviledge: if use_admin_priviledge:

View File

@ -1620,8 +1620,10 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
async def _is_server_notice_room(self, room_id: str) -> bool: async def _is_server_notice_room(self, room_id: str) -> bool:
if self._server_notices_mxid is None: if self._server_notices_mxid is None:
return False return False
user_ids = await self.store.get_users_in_room(room_id) is_server_notices_room = await self.store.check_local_user_in_room(
return self._server_notices_mxid in user_ids user_id=self._server_notices_mxid, room_id=room_id
)
return is_server_notices_room
class RoomMemberMasterHandler(RoomMemberHandler): class RoomMemberMasterHandler(RoomMemberHandler):

View File

@ -102,6 +102,10 @@ class ServerNoticesManager:
Returns: Returns:
The room's ID, or None if no room could be found. The room's ID, or None if no room could be found.
""" """
# If there is no server notices MXID, then there is no server notices room
if self.server_notices_mxid is None:
return None
rooms = await self._store.get_rooms_for_local_user_where_membership_is( rooms = await self._store.get_rooms_for_local_user_where_membership_is(
user_id, [Membership.INVITE, Membership.JOIN] user_id, [Membership.INVITE, Membership.JOIN]
) )
@ -111,8 +115,10 @@ class ServerNoticesManager:
# be joined. This is kinda deliberate, in that if somebody somehow # be joined. This is kinda deliberate, in that if somebody somehow
# manages to invite the system user to a room, that doesn't make it # manages to invite the system user to a room, that doesn't make it
# the server notices room. # the server notices room.
user_ids = await self._store.get_users_in_room(room.room_id) is_server_notices_room = await self._store.check_local_user_in_room(
if len(user_ids) <= 2 and self.server_notices_mxid in user_ids: user_id=self.server_notices_mxid, room_id=room.room_id
)
if is_server_notices_room:
# we found a room which our user shares with the system notice # we found a room which our user shares with the system notice
# user # user
return room.room_id return room.room_id

View File

@ -534,6 +534,32 @@ class RoomMemberWorkerStore(EventsWorkerStore):
desc="get_local_users_in_room", desc="get_local_users_in_room",
) )
async def check_local_user_in_room(self, user_id: str, room_id: str) -> bool:
"""
Check whether a given local user is currently joined to the given room.
Returns:
A boolean indicating whether the user is currently joined to the room
Raises:
Exeption when called with a non-local user to this homeserver
"""
if not self.hs.is_mine_id(user_id):
raise Exception(
"Cannot call 'check_local_user_in_room' on "
"non-local user %s" % (user_id,),
)
(
membership,
member_event_id,
) = await self.get_local_current_membership_for_user_in_room(
user_id=user_id,
room_id=room_id,
)
return membership == Membership.JOIN
async def get_local_current_membership_for_user_in_room( async def get_local_current_membership_for_user_in_room(
self, user_id: str, room_id: str self, user_id: str, room_id: str
) -> Tuple[Optional[str], Optional[str]]: ) -> Tuple[Optional[str], Optional[str]]:

View File

@ -999,7 +999,7 @@ class BundledAggregationsTestCase(BaseRelationsTestCase):
bundled_aggregations, bundled_aggregations,
) )
self._test_bundled_aggregations(RelationTypes.ANNOTATION, assert_annotations, 6) self._test_bundled_aggregations(RelationTypes.ANNOTATION, assert_annotations, 7)
def test_annotation_to_annotation(self) -> None: def test_annotation_to_annotation(self) -> None:
"""Any relation to an annotation should be ignored.""" """Any relation to an annotation should be ignored."""
@ -1035,7 +1035,7 @@ class BundledAggregationsTestCase(BaseRelationsTestCase):
bundled_aggregations, bundled_aggregations,
) )
self._test_bundled_aggregations(RelationTypes.REFERENCE, assert_annotations, 6) self._test_bundled_aggregations(RelationTypes.REFERENCE, assert_annotations, 7)
def test_thread(self) -> None: def test_thread(self) -> None:
""" """
@ -1080,21 +1080,21 @@ class BundledAggregationsTestCase(BaseRelationsTestCase):
# The "user" sent the root event and is making queries for the bundled # The "user" sent the root event and is making queries for the bundled
# aggregations: they have participated. # aggregations: they have participated.
self._test_bundled_aggregations(RelationTypes.THREAD, _gen_assert(True), 8) self._test_bundled_aggregations(RelationTypes.THREAD, _gen_assert(True), 9)
# The "user2" sent replies in the thread and is making queries for the # The "user2" sent replies in the thread and is making queries for the
# bundled aggregations: they have participated. # bundled aggregations: they have participated.
# #
# Note that this re-uses some cached values, so the total number of # Note that this re-uses some cached values, so the total number of
# queries is much smaller. # queries is much smaller.
self._test_bundled_aggregations( self._test_bundled_aggregations(
RelationTypes.THREAD, _gen_assert(True), 2, access_token=self.user2_token RelationTypes.THREAD, _gen_assert(True), 3, access_token=self.user2_token
) )
# A user with no interactions with the thread: they have not participated. # A user with no interactions with the thread: they have not participated.
user3_id, user3_token = self._create_user("charlie") user3_id, user3_token = self._create_user("charlie")
self.helper.join(self.room, user=user3_id, tok=user3_token) self.helper.join(self.room, user=user3_id, tok=user3_token)
self._test_bundled_aggregations( self._test_bundled_aggregations(
RelationTypes.THREAD, _gen_assert(False), 2, access_token=user3_token RelationTypes.THREAD, _gen_assert(False), 3, access_token=user3_token
) )
def test_thread_with_bundled_aggregations_for_latest(self) -> None: def test_thread_with_bundled_aggregations_for_latest(self) -> None:
@ -1142,7 +1142,7 @@ class BundledAggregationsTestCase(BaseRelationsTestCase):
bundled_aggregations["latest_event"].get("unsigned"), bundled_aggregations["latest_event"].get("unsigned"),
) )
self._test_bundled_aggregations(RelationTypes.THREAD, assert_thread, 8) self._test_bundled_aggregations(RelationTypes.THREAD, assert_thread, 9)
def test_nested_thread(self) -> None: def test_nested_thread(self) -> None:
""" """