Add `allow_departed_users` param to `check_in_room_or_world_readable`

... and set it everywhere it's called.

while we're here, rename it for consistency with `check_user_in_room` (and to
help check that I haven't missed any instances)
pull/6949/head
Richard van der Hoff 2020-02-18 23:14:57 +00:00
parent b58d17e44f
commit a0a1fd0bec
5 changed files with 33 additions and 15 deletions

View File

@ -625,10 +625,18 @@ class Auth(object):
return query_params[0].decode("ascii") return query_params[0].decode("ascii")
@defer.inlineCallbacks @defer.inlineCallbacks
def check_in_room_or_world_readable(self, room_id, user_id): def check_user_in_room_or_world_readable(
self, room_id: str, user_id: str, allow_departed_users: bool = False
):
"""Checks that the user is or was in the room or the room is world """Checks that the user is or was in the room or the room is world
readable. If it isn't then an exception is raised. readable. If it isn't then an exception is raised.
Args:
room_id: room to check
user_id: user to check
allow_departed_users: if True, accept users that were previously
members but have now departed
Returns: Returns:
Deferred[tuple[str, str|None]]: Resolves to the current membership of Deferred[tuple[str, str|None]]: Resolves to the current membership of
the user in the room and the membership event ID of the user. If the user in the room and the membership event ID of the user. If
@ -643,7 +651,7 @@ class Auth(object):
# * The user is a guest user, and has joined the room # * The user is a guest user, and has joined the room
# else it will throw. # else it will throw.
member_event = yield self.check_user_in_room( member_event = yield self.check_user_in_room(
room_id, user_id, allow_departed_users=True room_id, user_id, allow_departed_users=allow_departed_users
) )
return member_event.membership, member_event.event_id return member_event.membership, member_event.event_id
except AuthError: except AuthError:
@ -656,7 +664,9 @@ class Auth(object):
): ):
return Membership.JOIN, None return Membership.JOIN, None
raise AuthError( raise AuthError(
403, "Guest access not allowed", errcode=Codes.GUEST_ACCESS_FORBIDDEN 403,
"User %s not in room %s, and room previews are disabled"
% (user_id, room_id),
) )
@defer.inlineCallbacks @defer.inlineCallbacks

View File

@ -277,7 +277,9 @@ class InitialSyncHandler(BaseHandler):
( (
membership, membership,
member_event_id, member_event_id,
) = await self.auth.check_user_in_room_or_world_readable(room_id, user_id) ) = await self.auth.check_user_in_room_or_world_readable(
room_id, user_id, allow_departed_users=True,
)
is_peeking = member_event_id is None is_peeking = member_event_id is None
if membership == Membership.JOIN: if membership == Membership.JOIN:

View File

@ -99,7 +99,9 @@ class MessageHandler(object):
( (
membership, membership,
membership_event_id, membership_event_id,
) = yield self.auth.check_in_room_or_world_readable(room_id, user_id) ) = yield self.auth.check_user_in_room_or_world_readable(
room_id, user_id, allow_departed_users=True
)
if membership == Membership.JOIN: if membership == Membership.JOIN:
data = yield self.state.get_current_state(room_id, event_type, state_key) data = yield self.state.get_current_state(room_id, event_type, state_key)
@ -177,7 +179,9 @@ class MessageHandler(object):
( (
membership, membership,
membership_event_id, membership_event_id,
) = yield self.auth.check_in_room_or_world_readable(room_id, user_id) ) = yield self.auth.check_user_in_room_or_world_readable(
room_id, user_id, allow_departed_users=True
)
if membership == Membership.JOIN: if membership == Membership.JOIN:
state_ids = yield self.store.get_filtered_current_state_ids( state_ids = yield self.store.get_filtered_current_state_ids(
@ -216,8 +220,8 @@ class MessageHandler(object):
if not requester.app_service: if not requester.app_service:
# We check AS auth after fetching the room membership, as it # We check AS auth after fetching the room membership, as it
# requires us to pull out all joined members anyway. # requires us to pull out all joined members anyway.
membership, _ = yield self.auth.check_in_room_or_world_readable( membership, _ = yield self.auth.check_user_in_room_or_world_readable(
room_id, user_id room_id, user_id, allow_departed_users=True
) )
if membership != Membership.JOIN: if membership != Membership.JOIN:
raise NotImplementedError( raise NotImplementedError(

View File

@ -335,7 +335,9 @@ class PaginationHandler(object):
( (
membership, membership,
member_event_id, member_event_id,
) = await self.auth.check_in_room_or_world_readable(room_id, user_id) ) = await self.auth.check_user_in_room_or_world_readable(
room_id, user_id, allow_departed_users=True
)
if source_config.direction == "b": if source_config.direction == "b":
# if we're going backwards, we might need to backfill. This # if we're going backwards, we might need to backfill. This

View File

@ -142,8 +142,8 @@ class RelationPaginationServlet(RestServlet):
): ):
requester = await self.auth.get_user_by_req(request, allow_guest=True) requester = await self.auth.get_user_by_req(request, allow_guest=True)
await self.auth.check_in_room_or_world_readable( await self.auth.check_user_in_room_or_world_readable(
room_id, requester.user.to_string() room_id, requester.user.to_string(), allow_departed_users=True
) )
# This gets the original event and checks that a) the event exists and # This gets the original event and checks that a) the event exists and
@ -235,8 +235,8 @@ class RelationAggregationPaginationServlet(RestServlet):
): ):
requester = await self.auth.get_user_by_req(request, allow_guest=True) requester = await self.auth.get_user_by_req(request, allow_guest=True)
await self.auth.check_in_room_or_world_readable( await self.auth.check_user_in_room_or_world_readable(
room_id, requester.user.to_string() room_id, requester.user.to_string(), allow_departed_users=True,
) )
# This checks that a) the event exists and b) the user is allowed to # This checks that a) the event exists and b) the user is allowed to
@ -313,8 +313,8 @@ class RelationAggregationGroupPaginationServlet(RestServlet):
async def on_GET(self, request, room_id, parent_id, relation_type, event_type, key): async def on_GET(self, request, room_id, parent_id, relation_type, event_type, key):
requester = await self.auth.get_user_by_req(request, allow_guest=True) requester = await self.auth.get_user_by_req(request, allow_guest=True)
await self.auth.check_in_room_or_world_readable( await self.auth.check_user_in_room_or_world_readable(
room_id, requester.user.to_string() room_id, requester.user.to_string(), allow_departed_users=True,
) )
# This checks that a) the event exists and b) the user is allowed to # This checks that a) the event exists and b) the user is allowed to