diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index 1d0f0058a2..b2b9b928c9 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -823,15 +823,17 @@ class SyncHandler(BaseHandler): # TODO(mjark) Check for new redactions in the state events. with Measure(self.clock, "compute_state_delta"): + current_state = yield self.get_state_at( + room_id, stream_position=now_token + ) + if full_state: if batch: state = yield self.store.get_state_for_event( batch.events[0].event_id ) else: - state = yield self.get_state_at( - room_id, stream_position=now_token - ) + state = current_state timeline_state = { (event.type, event.state_key): event @@ -842,6 +844,7 @@ class SyncHandler(BaseHandler): timeline_contains=timeline_state, timeline_start=state, previous={}, + current=current_state, ) elif batch.limited: state_at_previous_sync = yield self.get_state_at( @@ -861,6 +864,7 @@ class SyncHandler(BaseHandler): timeline_contains=timeline_state, timeline_start=state_at_timeline_start, previous=state_at_previous_sync, + current=current_state, ) else: state = {} @@ -920,7 +924,7 @@ def _action_has_highlight(actions): return False -def _calculate_state(timeline_contains, timeline_start, previous): +def _calculate_state(timeline_contains, timeline_start, previous, current): """Works out what state to include in a sync response. Args: @@ -928,6 +932,7 @@ def _calculate_state(timeline_contains, timeline_start, previous): timeline_start (dict): state at the start of the timeline previous (dict): state at the end of the previous sync (or empty dict if this is an initial sync) + current (dict): state at the end of the timeline Returns: dict @@ -938,14 +943,16 @@ def _calculate_state(timeline_contains, timeline_start, previous): timeline_contains.values(), previous.values(), timeline_start.values(), + current.values(), ) } + c_ids = set(e.event_id for e in current.values()) tc_ids = set(e.event_id for e in timeline_contains.values()) p_ids = set(e.event_id for e in previous.values()) ts_ids = set(e.event_id for e in timeline_start.values()) - state_ids = (ts_ids - p_ids) - tc_ids + state_ids = ((c_ids | ts_ids) - p_ids) - tc_ids evs = (event_id_to_state[e] for e in state_ids) return {