diff --git a/synapse/api/auth.py b/synapse/api/auth.py index 0c0d678562..9b614a12bb 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -104,6 +104,20 @@ class Auth(object): @defer.inlineCallbacks def check_joined_room(self, room_id, user_id, current_state=None): + """Check if the user is currently joined in the room + Args: + room_id(str): The room to check. + user_id(str): The user to check. + current_state(dict): Optional map of the current state of the room. + If provided then that map is used to check whether they are a + member of the room. Otherwise the current membership is + loaded from the database. + Raises: + AuthError if the user is not in the room. + Returns: + A deferred membership event for the user if the user is in + the room. + """ if current_state: member = current_state.get( (EventTypes.Member, user_id), @@ -119,6 +133,41 @@ class Auth(object): self._check_joined_room(member, user_id, room_id) defer.returnValue(member) + @defer.inlineCallbacks + def check_user_was_in_room(self, room_id, user_id, current_state=None): + """Check if the user was in the room at some point. + Args: + room_id(str): The room to check. + user_id(str): The user to check. + current_state(dict): Optional map of the current state of the room. + If provided then that map is used to check whether they are a + member of the room. Otherwise the current membership is + loaded from the database. + Raises: + AuthError if the user was never in the room. + Returns: + A deferred membership event for the user if the user was in + the room. + """ + if current_state: + member = current_state.get( + (EventTypes.Member, user_id), + None + ) + else: + member = yield self.state.get_current_state( + room_id=room_id, + event_type=EventTypes.Member, + state_key=user_id + ) + + if not member: + raise AuthError(403, "User %s not in room %s" % ( + user_id, room_id + )) + + defer.returnValue(member) + @defer.inlineCallbacks def check_host_in_room(self, room_id, host): curr_state = yield self.state.get_current_state(room_id) diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 5447c97e83..fc9a234333 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -22,7 +22,7 @@ from synapse.events.utils import serialize_event from synapse.events.validator import EventValidator from synapse.util import unwrapFirstError from synapse.util.logcontext import PreserveLoggingContext -from synapse.types import UserID, RoomStreamToken +from synapse.types import UserID, RoomStreamToken, StreamToken from ._base import BaseHandler @@ -377,7 +377,6 @@ class MessageHandler(BaseHandler): lambda states: states[event.event_id] ) - (messages, token), current_state = yield defer.gatherResults( [ self.store.get_recent_events_for_room( @@ -434,13 +433,83 @@ class MessageHandler(BaseHandler): @defer.inlineCallbacks def room_initial_sync(self, user_id, room_id, pagin_config=None, feedback=False): - current_state = yield self.state.get_current_state( - room_id=room_id, + """Capture the a snapshot of a room. If user is currently a member of + the room this will be what is currently in the room. If the user left + the room this will be what was in the room when they left. + + Args: + user_id(str): The user to get a snapshot for. + room_id(str): The room to get a snapshot of. + pagin_config(synapse.api.streams.PaginationConfig): The pagination + config used to determine how many messages to return. + Raises: + AuthError if the user wasn't in the room. + Returns: + A JSON object with the snapshot of the room. + """ + + member_event = yield self.auth.check_user_was_in_room(room_id, user_id) + + if member_event.membership == Membership.JOIN: + result = yield self._room_initial_sync_joined( + user_id, room_id, pagin_config, member_event + ) + elif member_event.membership == Membership.LEAVE: + result = yield self._room_initial_sync_parted( + user_id, room_id, pagin_config, member_event + ) + defer.returnValue(result) + + @defer.inlineCallbacks + def _room_initial_sync_parted(self, user_id, room_id, pagin_config, + member_event): + room_state = yield self.store.get_state_for_events( + member_event.room_id, [member_event.event_id], None ) - yield self.auth.check_joined_room( - room_id, user_id, - current_state=current_state + room_state = room_state[member_event.event_id] + + limit = pagin_config.limit if pagin_config else None + if limit is None: + limit = 10 + + stream_token = yield self.store.get_stream_token_for_event( + member_event.event_id + ) + + messages, token = yield self.store.get_recent_events_for_room( + room_id, + limit=limit, + end_token=stream_token + ) + + messages = yield self._filter_events_for_client( + user_id, room_id, messages + ) + + start_token = StreamToken(token[0], 0, 0, 0) + end_token = StreamToken(token[1], 0, 0, 0) + + time_now = self.clock.time_msec() + + defer.returnValue({ + "membership": member_event.membership, + "room_id": room_id, + "messages": { + "chunk": [serialize_event(m, time_now) for m in messages], + "start": start_token.to_string(), + "end": end_token.to_string(), + }, + "state": [serialize_event(s, time_now) for s in room_state.values()], + "presence": [], + "receipts": [], + }) + + @defer.inlineCallbacks + def _room_initial_sync_joined(self, user_id, room_id, pagin_config, + member_event): + current_state = yield self.state.get_current_state( + room_id=room_id, ) # TODO(paul): I wish I was called with user objects not user_id @@ -454,8 +523,6 @@ class MessageHandler(BaseHandler): for x in current_state.values() ] - member_event = current_state.get((EventTypes.Member, user_id,)) - now_token = yield self.hs.get_event_sources().get_current_token() limit = pagin_config.limit if pagin_config else None diff --git a/synapse/storage/stream.py b/synapse/storage/stream.py index d7fe423f5a..0abfa86cd2 100644 --- a/synapse/storage/stream.py +++ b/synapse/storage/stream.py @@ -379,6 +379,21 @@ class StreamStore(SQLBaseStore): ) defer.returnValue("t%d-%d" % (topo, token)) + def get_stream_token_for_event(self, event_id): + """The stream token for an event + Args: + event_id(str): The id of the event to look up a stream token for. + Raises: + StoreError if the event wasn't in the database. + Returns: + A deferred "s%d" stream token. + """ + return self._simple_select_one_onecol( + table="events", + keyvalues={"event_id": event_id}, + retcol="stream_ordering", + ).addCallback(lambda stream_ordering: "s%d" % (stream_ordering,)) + def _get_max_topological_txn(self, txn): txn.execute( "SELECT MAX(topological_ordering) FROM events"