Fix a few race conditions in the state calculation

Be a bit more careful about how we calculate the state to be returned by
/sync. In a few places, it was possible for /sync to return slightly later
state than that represented by the next_batch token and the timeline. In
particular, the following cases were susceptible:

* On a full state sync, for an active room
* During a per-room incremental sync with a timeline gap
* When the user has just joined a room. (Refactor check_joined_room to make it
  less magical)

Also, use store.get_state_for_events() (and thus the existing stategroups) to
calculate the state corresponding to a particular sync position, rather than
state_handler.compute_event_context(), which recalculates from first principles
(and tends to miss some state).

Merged from PR https://github.com/matrix-org/synapse/pull/372
pull/374/head
Richard van der Hoff 2015-11-10 18:27:23 +00:00
parent 5ab4b0afe8
commit fddedd51d9
2 changed files with 78 additions and 61 deletions

View File

@ -254,9 +254,7 @@ class SyncHandler(BaseHandler):
room_id, sync_config, now_token, since_token=timeline_since_token
)
current_state = yield self.state_handler.get_current_state(
room_id
)
current_state = yield self.get_state_at(room_id, now_token)
defer.returnValue(JoinedSyncResult(
room_id=room_id,
@ -353,14 +351,12 @@ class SyncHandler(BaseHandler):
room_id, sync_config, leave_token, since_token=timeline_since_token
)
leave_state = yield self.store.get_state_for_events(
[leave_event_id], None
)
leave_state = yield self.store.get_state_for_event(leave_event_id)
defer.returnValue(ArchivedSyncResult(
room_id=room_id,
timeline=batch,
state=leave_state[leave_event_id],
state=leave_state,
private_user_data=self.private_user_data_for_room(
room_id, tags_by_room
),
@ -424,6 +420,9 @@ class SyncHandler(BaseHandler):
if len(room_events) <= timeline_limit:
# There is no gap in any of the rooms. Therefore we can just
# partition the new events by room and return them.
logger.debug("Got %i events for incremental sync - not limited",
len(room_events))
invite_events = []
leave_events = []
events_by_room_id = {}
@ -439,9 +438,11 @@ class SyncHandler(BaseHandler):
for room_id in joined_room_ids:
recents = events_by_room_id.get(room_id, [])
logger.debug("Events for room %s: %r", room_id, recents)
state = {
(event.type, event.state_key): event
for event in recents if event.is_state()}
limited = False
if recents:
prev_batch = now_token.copy_and_replace(
@ -450,9 +451,13 @@ class SyncHandler(BaseHandler):
else:
prev_batch = now_token
state, limited = yield self.check_joined_room(
sync_config, room_id, state
)
just_joined = yield self.check_joined_room(sync_config, state)
if just_joined:
logger.debug("User has just joined %s: needs full state",
room_id)
state = yield self.get_state_at(room_id, now_token)
# the timeline is inherently limited if we've just joined
limited = True
room_sync = JoinedSyncResult(
room_id=room_id,
@ -467,10 +472,15 @@ class SyncHandler(BaseHandler):
room_id, tags_by_room
),
)
logger.debug("Result for room %s: %r", room_id, room_sync)
if room_sync:
joined.append(room_sync)
else:
logger.debug("Got %i events for incremental sync - hit limit",
len(room_events))
invite_events = yield self.store.get_invites_for_user(
sync_config.user.to_string()
)
@ -563,6 +573,8 @@ class SyncHandler(BaseHandler):
Returns:
A Deferred JoinedSyncResult
"""
logger.debug("Doing incremental sync for room %s between %s and %s",
room_id, since_token, now_token)
# TODO(mjark): Check for redactions we might have missed.
@ -572,30 +584,26 @@ class SyncHandler(BaseHandler):
logging.debug("Recents %r", batch)
# TODO(mjark): This seems racy since this isn't being passed a
# token to indicate what point in the stream this is
current_state = yield self.state_handler.get_current_state(
room_id
current_state = yield self.get_state_at(room_id, now_token)
state_at_previous_sync = yield self.get_state_at(
room_id, stream_position=since_token
)
state_at_previous_sync = yield self.get_state_at_previous_sync(
room_id, since_token=since_token
)
state_events_delta = yield self.compute_state_delta(
state = yield self.compute_state_delta(
since_token=since_token,
previous_state=state_at_previous_sync,
current_state=current_state,
)
state_events_delta, _ = yield self.check_joined_room(
sync_config, room_id, state_events_delta
)
just_joined = yield self.check_joined_room(sync_config, state)
if just_joined:
state = yield self.get_state_at(room_id, now_token)
room_sync = JoinedSyncResult(
room_id=room_id,
timeline=batch,
state=state_events_delta,
state=state,
ephemeral=ephemeral_by_room.get(room_id, []),
private_user_data=self.private_user_data_for_room(
room_id, tags_by_room
@ -627,16 +635,12 @@ class SyncHandler(BaseHandler):
logging.debug("Recents %r", batch)
# TODO(mjark): This seems racy since this isn't being passed a
# token to indicate what point in the stream this is
leave_state = yield self.store.get_state_for_events(
[leave_event.event_id], None
state_events_at_leave = yield self.store.get_state_for_event(
leave_event.event_id
)
state_events_at_leave = leave_state[leave_event.event_id]
state_at_previous_sync = yield self.get_state_at_previous_sync(
leave_event.room_id, since_token=since_token
state_at_previous_sync = yield self.get_state_at(
leave_event.room_id, stream_position=since_token
)
state_events_delta = yield self.compute_state_delta(
@ -659,26 +663,36 @@ class SyncHandler(BaseHandler):
defer.returnValue(room_sync)
@defer.inlineCallbacks
def get_state_at_previous_sync(self, room_id, since_token):
""" Get the room state at the previous sync the client made.
Returns:
A Deferred map from ((type, state_key)->Event)
def get_state_after_event(self, event):
"""
Get the room state after the given event
:param synapse.events.EventBase event: event of interest
:return: A Deferred map from ((type, state_key)->Event)
"""
state = yield self.store.get_state_for_event(event.event_id)
if event.is_state():
state = state.copy()
state[(event.type, event.state_key)] = event
defer.returnValue(state)
@defer.inlineCallbacks
def get_state_at(self, room_id, stream_position):
""" Get the room state at a particular stream position
:param str room_id: room for which to get state
:param StreamToken stream_position: point at which to get state
:returns: A Deferred map from ((type, state_key)->Event)
"""
last_events, token = yield self.store.get_recent_events_for_room(
room_id, end_token=since_token.room_key, limit=1,
room_id, end_token=stream_position.room_key, limit=1,
)
if last_events:
last_event = last_events[0]
last_context = yield self.state_handler.compute_event_context(
last_event
)
if last_event.is_state():
state = last_context.current_state.copy()
state[(last_event.type, last_event.state_key)] = last_event
else:
state = last_context.current_state
last_event = last_events[-1]
state = yield self.get_state_after_event(last_event)
else:
# no events in this room - so presumably no state
state = {}
defer.returnValue(state)
@ -706,31 +720,20 @@ class SyncHandler(BaseHandler):
state_delta[key] = event
return state_delta
@defer.inlineCallbacks
def check_joined_room(self, sync_config, room_id, state_delta):
def check_joined_room(self, sync_config, state_delta):
"""
Check if the user has just joined the given room. If so, return the
full state for the room, instead of the delta since the last sync.
Check if the user has just joined the given room (so should
be given the full state)
:param sync_config:
:param room_id:
:param dict[(str,str), synapse.events.FrozenEvent] state_delta: the
difference in state since the last sync
:returns A deferred Tuple (state_delta, limited)
"""
joined = False
limited = False
join_event = state_delta.get((
EventTypes.Member, sync_config.user.to_string()), None)
if join_event is not None:
if join_event.content["membership"] == Membership.JOIN:
joined = True
if joined:
state_delta = yield self.state_handler.get_current_state(room_id)
# the timeline is inherently limited if we've just joined
limited = True
defer.returnValue((state_delta, limited))
return True
return False

View File

@ -237,6 +237,20 @@ class StateStore(SQLBaseStore):
defer.returnValue({event: event_to_state[event] for event in event_ids})
@defer.inlineCallbacks
def get_state_for_event(self, event_id, types=None):
"""
Get the state dict corresponding to a particular event
:param str event_id: event whose state should be returned
:param list[(str, str)]|None types: List of (type, state_key) tuples
which are used to filter the state fetched. May be None, which
matches any key
:return: a deferred dict from (type, state_key) -> state_event
"""
state_map = yield self.get_state_for_events([event_id], types)
defer.returnValue(state_map[event_id])
@cached(num_args=2, lru=True, max_entries=10000)
def _get_state_group_for_event(self, room_id, event_id):
return self._simple_select_one_onecol(