Allow filtering events for multiple users at once
							parent
							
								
									5de1563997
								
							
						
					
					
						commit
						cc66a9a5e3
					
				| 
						 | 
				
			
			@ -53,16 +53,54 @@ class BaseHandler(object):
 | 
			
		|||
        self.event_builder_factory = hs.get_event_builder_factory()
 | 
			
		||||
 | 
			
		||||
    @defer.inlineCallbacks
 | 
			
		||||
    def _filter_events_for_client(self, user_id, events, is_guest=False):
 | 
			
		||||
        # Assumes that user has at some point joined the room if not is_guest.
 | 
			
		||||
    def _filter_events_for_clients(self, users, events):
 | 
			
		||||
        """ Returns dict of user_id -> list of events that user is allowed to
 | 
			
		||||
        see.
 | 
			
		||||
        """
 | 
			
		||||
        event_id_to_state = yield self.store.get_state_for_events(
 | 
			
		||||
            frozenset(e.event_id for e in events),
 | 
			
		||||
            types=(
 | 
			
		||||
                (EventTypes.RoomHistoryVisibility, ""),
 | 
			
		||||
                (EventTypes.Member, None),
 | 
			
		||||
            )
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        forgotten = yield defer.gatherResults([
 | 
			
		||||
            self.store.who_forgot_in_room(
 | 
			
		||||
                room_id,
 | 
			
		||||
            )
 | 
			
		||||
            for room_id in frozenset(e.room_id for e in events)
 | 
			
		||||
        ], consumeErrors=True)
 | 
			
		||||
 | 
			
		||||
        # Set of membership event_ids that have been forgotten
 | 
			
		||||
        event_id_forgotten = frozenset(
 | 
			
		||||
            row["event_id"] for rows in forgotten for row in rows
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        def allowed(event, user_id, is_guest):
 | 
			
		||||
            state = event_id_to_state[event.event_id]
 | 
			
		||||
 | 
			
		||||
            visibility_event = state.get((EventTypes.RoomHistoryVisibility, ""), None)
 | 
			
		||||
            if visibility_event:
 | 
			
		||||
                visibility = visibility_event.content.get("history_visibility", "shared")
 | 
			
		||||
            else:
 | 
			
		||||
                visibility = "shared"
 | 
			
		||||
 | 
			
		||||
        def allowed(event, membership, visibility):
 | 
			
		||||
            if visibility == "world_readable":
 | 
			
		||||
                return True
 | 
			
		||||
 | 
			
		||||
            if is_guest:
 | 
			
		||||
                return False
 | 
			
		||||
 | 
			
		||||
            membership_event = state.get((EventTypes.Member, user_id), None)
 | 
			
		||||
            if membership_event:
 | 
			
		||||
                if membership_event.event_id in event_id_forgotten:
 | 
			
		||||
                    membership = None
 | 
			
		||||
                else:
 | 
			
		||||
                    membership = membership_event.membership
 | 
			
		||||
            else:
 | 
			
		||||
                membership = None
 | 
			
		||||
 | 
			
		||||
            if membership == Membership.JOIN:
 | 
			
		||||
                return True
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -78,43 +116,20 @@ class BaseHandler(object):
 | 
			
		|||
 | 
			
		||||
            return True
 | 
			
		||||
 | 
			
		||||
        event_id_to_state = yield self.store.get_state_for_events(
 | 
			
		||||
            frozenset(e.event_id for e in events),
 | 
			
		||||
            types=(
 | 
			
		||||
                (EventTypes.RoomHistoryVisibility, ""),
 | 
			
		||||
                (EventTypes.Member, user_id),
 | 
			
		||||
            )
 | 
			
		||||
        )
 | 
			
		||||
        defer.returnValue({
 | 
			
		||||
            user_id: [
 | 
			
		||||
                event
 | 
			
		||||
                for event in events
 | 
			
		||||
                if allowed(event, user_id, is_guest)
 | 
			
		||||
            ]
 | 
			
		||||
            for user_id, is_guest in users
 | 
			
		||||
        })
 | 
			
		||||
 | 
			
		||||
        events_to_return = []
 | 
			
		||||
        for event in events:
 | 
			
		||||
            state = event_id_to_state[event.event_id]
 | 
			
		||||
 | 
			
		||||
            membership_event = state.get((EventTypes.Member, user_id), None)
 | 
			
		||||
            if membership_event:
 | 
			
		||||
                was_forgotten_at_event = yield self.store.was_forgotten_at(
 | 
			
		||||
                    membership_event.state_key,
 | 
			
		||||
                    membership_event.room_id,
 | 
			
		||||
                    membership_event.event_id
 | 
			
		||||
                )
 | 
			
		||||
                if was_forgotten_at_event:
 | 
			
		||||
                    membership = None
 | 
			
		||||
                else:
 | 
			
		||||
                    membership = membership_event.membership
 | 
			
		||||
            else:
 | 
			
		||||
                membership = None
 | 
			
		||||
 | 
			
		||||
            visibility_event = state.get((EventTypes.RoomHistoryVisibility, ""), None)
 | 
			
		||||
            if visibility_event:
 | 
			
		||||
                visibility = visibility_event.content.get("history_visibility", "shared")
 | 
			
		||||
            else:
 | 
			
		||||
                visibility = "shared"
 | 
			
		||||
 | 
			
		||||
            should_include = allowed(event, membership, visibility)
 | 
			
		||||
            if should_include:
 | 
			
		||||
                events_to_return.append(event)
 | 
			
		||||
 | 
			
		||||
        defer.returnValue(events_to_return)
 | 
			
		||||
    @defer.inlineCallbacks
 | 
			
		||||
    def _filter_events_for_client(self, user_id, events, is_guest=False):
 | 
			
		||||
        # Assumes that user has at some point joined the room if not is_guest.
 | 
			
		||||
        res = yield self._filter_events_for_clients([(user_id, is_guest)], events)
 | 
			
		||||
        defer.returnValue(res.get(user_id, []))
 | 
			
		||||
 | 
			
		||||
    def ratelimit(self, user_id):
 | 
			
		||||
        time_now = self.clock.time()
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -287,6 +287,7 @@ class RoomMemberStore(SQLBaseStore):
 | 
			
		|||
            txn.execute(sql, (user_id, room_id))
 | 
			
		||||
        yield self.runInteraction("forget_membership", f)
 | 
			
		||||
        self.was_forgotten_at.invalidate_all()
 | 
			
		||||
        self.who_forgot_in_room.invalidate_all()
 | 
			
		||||
        self.did_forget.invalidate((user_id, room_id))
 | 
			
		||||
 | 
			
		||||
    @cachedInlineCallbacks(num_args=2)
 | 
			
		||||
| 
						 | 
				
			
			@ -336,3 +337,15 @@ class RoomMemberStore(SQLBaseStore):
 | 
			
		|||
            return rows[0][0]
 | 
			
		||||
        forgot = yield self.runInteraction("did_forget_membership_at", f)
 | 
			
		||||
        defer.returnValue(forgot == 1)
 | 
			
		||||
 | 
			
		||||
    @cached()
 | 
			
		||||
    def who_forgot_in_room(self, room_id):
 | 
			
		||||
        return self._simple_select_list(
 | 
			
		||||
            table="room_memberships",
 | 
			
		||||
            retcols=("user_id", "event_id"),
 | 
			
		||||
            keyvalues={
 | 
			
		||||
                "room_id": room_id,
 | 
			
		||||
                "forgotten": 1,
 | 
			
		||||
            },
 | 
			
		||||
            desc="who_forgot"
 | 
			
		||||
        )
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in New Issue