Make sync not pull out full state

pull/1047/head
Erik Johnston 2016-08-25 18:59:44 +01:00
parent 7356d52e73
commit 778fa85f47
2 changed files with 74 additions and 34 deletions

View File

@ -355,11 +355,11 @@ class SyncHandler(object):
Returns: Returns:
A Deferred map from ((type, state_key)->Event) A Deferred map from ((type, state_key)->Event)
""" """
state = yield self.store.get_state_for_event(event.event_id) state_ids = yield self.store.get_state_ids_for_event(event.event_id)
if event.is_state(): if event.is_state():
state = state.copy() state_ids = state_ids.copy()
state[(event.type, event.state_key)] = event state_ids[(event.type, event.state_key)] = event.event_id
defer.returnValue(state) defer.returnValue(state_ids)
@defer.inlineCallbacks @defer.inlineCallbacks
def get_state_at(self, room_id, stream_position): def get_state_at(self, room_id, stream_position):
@ -412,57 +412,61 @@ class SyncHandler(object):
with Measure(self.clock, "compute_state_delta"): with Measure(self.clock, "compute_state_delta"):
if full_state: if full_state:
if batch: if batch:
current_state = yield self.store.get_state_for_event( current_state_ids = yield self.store.get_state_ids_for_event(
batch.events[-1].event_id batch.events[-1].event_id
) )
state = yield self.store.get_state_for_event( state_ids = yield self.store.get_state_ids_for_event(
batch.events[0].event_id batch.events[0].event_id
) )
else: else:
current_state = yield self.get_state_at( current_state_ids = yield self.get_state_at(
room_id, stream_position=now_token room_id, stream_position=now_token
) )
state = current_state state_ids = current_state_ids
timeline_state = { timeline_state = {
(event.type, event.state_key): event (event.type, event.state_key): event.event_id
for event in batch.events if event.is_state() for event in batch.events if event.is_state()
} }
state = _calculate_state( state_ids = _calculate_state(
timeline_contains=timeline_state, timeline_contains=timeline_state,
timeline_start=state, timeline_start=state_ids,
previous={}, previous={},
current=current_state, current=current_state_ids,
) )
elif batch.limited: elif batch.limited:
state_at_previous_sync = yield self.get_state_at( state_at_previous_sync = yield self.get_state_at(
room_id, stream_position=since_token room_id, stream_position=since_token
) )
current_state = yield self.store.get_state_for_event( current_state_ids = yield self.store.get_state_ids_for_event(
batch.events[-1].event_id batch.events[-1].event_id
) )
state_at_timeline_start = yield self.store.get_state_for_event( state_at_timeline_start = yield self.store.get_state_ids_for_event(
batch.events[0].event_id batch.events[0].event_id
) )
timeline_state = { timeline_state = {
(event.type, event.state_key): event (event.type, event.state_key): event.event_id
for event in batch.events if event.is_state() for event in batch.events if event.is_state()
} }
state = _calculate_state( state_ids = _calculate_state(
timeline_contains=timeline_state, timeline_contains=timeline_state,
timeline_start=state_at_timeline_start, timeline_start=state_at_timeline_start,
previous=state_at_previous_sync, previous=state_at_previous_sync,
current=current_state, current=current_state_ids,
) )
else: else:
state = {} state_ids = {}
state = {}
if state_ids:
state = yield self.store.get_events(state_ids.values())
defer.returnValue({ defer.returnValue({
(e.type, e.state_key): e (e.type, e.state_key): e
@ -766,8 +770,13 @@ class SyncHandler(object):
# the last sync (even if we have since left). This is to make sure # the last sync (even if we have since left). This is to make sure
# we do send down the room, and with full state, where necessary # we do send down the room, and with full state, where necessary
if room_id in joined_room_ids or has_join: if room_id in joined_room_ids or has_join:
old_state = yield self.get_state_at(room_id, since_token) old_state_ids = yield self.get_state_at(room_id, since_token)
old_mem_ev = old_state.get((EventTypes.Member, user_id), None) old_mem_ev_id = old_state_ids.get((EventTypes.Member, user_id), None)
old_mem_ev = None
if old_mem_ev_id:
old_mem_ev = yield self.store.get_event(
old_mem_ev_id, allow_none=True
)
if not old_mem_ev or old_mem_ev.membership != Membership.JOIN: if not old_mem_ev or old_mem_ev.membership != Membership.JOIN:
newly_joined_rooms.append(room_id) newly_joined_rooms.append(room_id)
@ -1059,27 +1068,25 @@ def _calculate_state(timeline_contains, timeline_start, previous, current):
Returns: Returns:
dict dict
""" """
event_id_to_state = { event_id_to_key = {
e.event_id: e e: key
for e in itertools.chain( for key, e in itertools.chain(
timeline_contains.values(), timeline_contains.items(),
previous.values(), previous.items(),
timeline_start.values(), timeline_start.items(),
current.values(), current.items(),
) )
} }
c_ids = set(e.event_id for e in current.values()) c_ids = set(e for e in current.values())
tc_ids = set(e.event_id for e in timeline_contains.values()) tc_ids = set(e for e in timeline_contains.values())
p_ids = set(e.event_id for e in previous.values()) p_ids = set(e for e in previous.values())
ts_ids = set(e.event_id for e in timeline_start.values()) ts_ids = set(e for e in timeline_start.values())
state_ids = ((c_ids | ts_ids) - p_ids) - tc_ids state_ids = ((c_ids | ts_ids) - p_ids) - tc_ids
evs = (event_id_to_state[e] for e in state_ids)
return { return {
(e.type, e.state_key): e event_id_to_key[e]: e for e in state_ids
for e in evs
} }

View File

@ -283,6 +283,22 @@ class StateStore(SQLBaseStore):
defer.returnValue({event: event_to_state[event] for event in event_ids}) defer.returnValue({event: event_to_state[event] for event in event_ids})
@defer.inlineCallbacks
def get_state_ids_for_events(self, event_ids, types):
event_to_groups = yield self._get_state_group_for_events(
event_ids,
)
groups = set(event_to_groups.values())
group_to_state = yield self._get_state_for_groups(groups, types)
event_to_state = {
event_id: group_to_state[group]
for event_id, group in event_to_groups.items()
}
defer.returnValue({event: event_to_state[event] for event in event_ids})
@defer.inlineCallbacks @defer.inlineCallbacks
def get_state_for_event(self, event_id, types=None): def get_state_for_event(self, event_id, types=None):
""" """
@ -300,6 +316,23 @@ class StateStore(SQLBaseStore):
state_map = yield self.get_state_for_events([event_id], types) state_map = yield self.get_state_for_events([event_id], types)
defer.returnValue(state_map[event_id]) defer.returnValue(state_map[event_id])
@defer.inlineCallbacks
def get_state_ids_for_event(self, event_id, types=None):
"""
Get the state dict corresponding to a particular event
Args:
event_id(str): event whose state should be returned
types(list[(str, str)]|None): List of (type, state_key) tuples
which are used to filter the state fetched. May be None, which
matches any key
Returns:
A deferred dict from (type, state_key) -> state_event
"""
state_map = yield self.get_state_ids_for_events([event_id], types)
defer.returnValue(state_map[event_id])
@cached(num_args=2, max_entries=10000) @cached(num_args=2, max_entries=10000)
def _get_state_group_for_event(self, room_id, event_id): def _get_state_group_for_event(self, room_id, event_id):
return self._simple_select_one_onecol( return self._simple_select_one_onecol(