Add `allow_departed_users` param to `check_in_room_or_world_readable`
... and set it everywhere it's called. while we're here, rename it for consistency with `check_user_in_room` (and to help check that I haven't missed any instances)pull/6949/head
parent
b58d17e44f
commit
a0a1fd0bec
|
@ -625,10 +625,18 @@ class Auth(object):
|
||||||
return query_params[0].decode("ascii")
|
return query_params[0].decode("ascii")
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def check_in_room_or_world_readable(self, room_id, user_id):
|
def check_user_in_room_or_world_readable(
|
||||||
|
self, room_id: str, user_id: str, allow_departed_users: bool = False
|
||||||
|
):
|
||||||
"""Checks that the user is or was in the room or the room is world
|
"""Checks that the user is or was in the room or the room is world
|
||||||
readable. If it isn't then an exception is raised.
|
readable. If it isn't then an exception is raised.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
room_id: room to check
|
||||||
|
user_id: user to check
|
||||||
|
allow_departed_users: if True, accept users that were previously
|
||||||
|
members but have now departed
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Deferred[tuple[str, str|None]]: Resolves to the current membership of
|
Deferred[tuple[str, str|None]]: Resolves to the current membership of
|
||||||
the user in the room and the membership event ID of the user. If
|
the user in the room and the membership event ID of the user. If
|
||||||
|
@ -643,7 +651,7 @@ class Auth(object):
|
||||||
# * The user is a guest user, and has joined the room
|
# * The user is a guest user, and has joined the room
|
||||||
# else it will throw.
|
# else it will throw.
|
||||||
member_event = yield self.check_user_in_room(
|
member_event = yield self.check_user_in_room(
|
||||||
room_id, user_id, allow_departed_users=True
|
room_id, user_id, allow_departed_users=allow_departed_users
|
||||||
)
|
)
|
||||||
return member_event.membership, member_event.event_id
|
return member_event.membership, member_event.event_id
|
||||||
except AuthError:
|
except AuthError:
|
||||||
|
@ -656,7 +664,9 @@ class Auth(object):
|
||||||
):
|
):
|
||||||
return Membership.JOIN, None
|
return Membership.JOIN, None
|
||||||
raise AuthError(
|
raise AuthError(
|
||||||
403, "Guest access not allowed", errcode=Codes.GUEST_ACCESS_FORBIDDEN
|
403,
|
||||||
|
"User %s not in room %s, and room previews are disabled"
|
||||||
|
% (user_id, room_id),
|
||||||
)
|
)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
|
|
|
@ -277,7 +277,9 @@ class InitialSyncHandler(BaseHandler):
|
||||||
(
|
(
|
||||||
membership,
|
membership,
|
||||||
member_event_id,
|
member_event_id,
|
||||||
) = await self.auth.check_user_in_room_or_world_readable(room_id, user_id)
|
) = await self.auth.check_user_in_room_or_world_readable(
|
||||||
|
room_id, user_id, allow_departed_users=True,
|
||||||
|
)
|
||||||
is_peeking = member_event_id is None
|
is_peeking = member_event_id is None
|
||||||
|
|
||||||
if membership == Membership.JOIN:
|
if membership == Membership.JOIN:
|
||||||
|
|
|
@ -99,7 +99,9 @@ class MessageHandler(object):
|
||||||
(
|
(
|
||||||
membership,
|
membership,
|
||||||
membership_event_id,
|
membership_event_id,
|
||||||
) = yield self.auth.check_in_room_or_world_readable(room_id, user_id)
|
) = yield self.auth.check_user_in_room_or_world_readable(
|
||||||
|
room_id, user_id, allow_departed_users=True
|
||||||
|
)
|
||||||
|
|
||||||
if membership == Membership.JOIN:
|
if membership == Membership.JOIN:
|
||||||
data = yield self.state.get_current_state(room_id, event_type, state_key)
|
data = yield self.state.get_current_state(room_id, event_type, state_key)
|
||||||
|
@ -177,7 +179,9 @@ class MessageHandler(object):
|
||||||
(
|
(
|
||||||
membership,
|
membership,
|
||||||
membership_event_id,
|
membership_event_id,
|
||||||
) = yield self.auth.check_in_room_or_world_readable(room_id, user_id)
|
) = yield self.auth.check_user_in_room_or_world_readable(
|
||||||
|
room_id, user_id, allow_departed_users=True
|
||||||
|
)
|
||||||
|
|
||||||
if membership == Membership.JOIN:
|
if membership == Membership.JOIN:
|
||||||
state_ids = yield self.store.get_filtered_current_state_ids(
|
state_ids = yield self.store.get_filtered_current_state_ids(
|
||||||
|
@ -216,8 +220,8 @@ class MessageHandler(object):
|
||||||
if not requester.app_service:
|
if not requester.app_service:
|
||||||
# We check AS auth after fetching the room membership, as it
|
# We check AS auth after fetching the room membership, as it
|
||||||
# requires us to pull out all joined members anyway.
|
# requires us to pull out all joined members anyway.
|
||||||
membership, _ = yield self.auth.check_in_room_or_world_readable(
|
membership, _ = yield self.auth.check_user_in_room_or_world_readable(
|
||||||
room_id, user_id
|
room_id, user_id, allow_departed_users=True
|
||||||
)
|
)
|
||||||
if membership != Membership.JOIN:
|
if membership != Membership.JOIN:
|
||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
|
|
|
@ -335,7 +335,9 @@ class PaginationHandler(object):
|
||||||
(
|
(
|
||||||
membership,
|
membership,
|
||||||
member_event_id,
|
member_event_id,
|
||||||
) = await self.auth.check_in_room_or_world_readable(room_id, user_id)
|
) = await self.auth.check_user_in_room_or_world_readable(
|
||||||
|
room_id, user_id, allow_departed_users=True
|
||||||
|
)
|
||||||
|
|
||||||
if source_config.direction == "b":
|
if source_config.direction == "b":
|
||||||
# if we're going backwards, we might need to backfill. This
|
# if we're going backwards, we might need to backfill. This
|
||||||
|
|
|
@ -142,8 +142,8 @@ class RelationPaginationServlet(RestServlet):
|
||||||
):
|
):
|
||||||
requester = await self.auth.get_user_by_req(request, allow_guest=True)
|
requester = await self.auth.get_user_by_req(request, allow_guest=True)
|
||||||
|
|
||||||
await self.auth.check_in_room_or_world_readable(
|
await self.auth.check_user_in_room_or_world_readable(
|
||||||
room_id, requester.user.to_string()
|
room_id, requester.user.to_string(), allow_departed_users=True
|
||||||
)
|
)
|
||||||
|
|
||||||
# This gets the original event and checks that a) the event exists and
|
# This gets the original event and checks that a) the event exists and
|
||||||
|
@ -235,8 +235,8 @@ class RelationAggregationPaginationServlet(RestServlet):
|
||||||
):
|
):
|
||||||
requester = await self.auth.get_user_by_req(request, allow_guest=True)
|
requester = await self.auth.get_user_by_req(request, allow_guest=True)
|
||||||
|
|
||||||
await self.auth.check_in_room_or_world_readable(
|
await self.auth.check_user_in_room_or_world_readable(
|
||||||
room_id, requester.user.to_string()
|
room_id, requester.user.to_string(), allow_departed_users=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
# This checks that a) the event exists and b) the user is allowed to
|
# This checks that a) the event exists and b) the user is allowed to
|
||||||
|
@ -313,8 +313,8 @@ class RelationAggregationGroupPaginationServlet(RestServlet):
|
||||||
async def on_GET(self, request, room_id, parent_id, relation_type, event_type, key):
|
async def on_GET(self, request, room_id, parent_id, relation_type, event_type, key):
|
||||||
requester = await self.auth.get_user_by_req(request, allow_guest=True)
|
requester = await self.auth.get_user_by_req(request, allow_guest=True)
|
||||||
|
|
||||||
await self.auth.check_in_room_or_world_readable(
|
await self.auth.check_user_in_room_or_world_readable(
|
||||||
room_id, requester.user.to_string()
|
room_id, requester.user.to_string(), allow_departed_users=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
# This checks that a) the event exists and b) the user is allowed to
|
# This checks that a) the event exists and b) the user is allowed to
|
||||||
|
|
Loading…
Reference in New Issue