diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py index 257c5584d0..cd708e1c99 100644 --- a/synapse/storage/databases/main/stream.py +++ b/synapse/storage/databases/main/stream.py @@ -211,44 +211,50 @@ def _make_generic_sql_bound( def _filter_results( - direction: str, - from_token: Optional[RoomStreamToken], - to_token: Optional[RoomStreamToken], + lower_token: Optional[RoomStreamToken], + upper_token: Optional[RoomStreamToken], instance_name: str, + topological_ordering: int, stream_ordering: int, ) -> bool: - """Filter results from fetching events in the DB against the given tokens. + """Returns True if the event persisted by the given instance at the given + topological/stream_ordering falls between the two tokens (taking a None + token to mean unbounded). - This is necessary to handle the case where the tokens include position - maps, which we handle by fetching more than necessary from the DB and then - filtering (rather than attempting to construct a complicated SQL query). + Used to filter results from fetching events in the DB against the given + tokens. This is necessary to handle the case where the tokens include + position maps, which we handle by fetching more than necessary from the DB + and then filtering (rather than attempting to construct a complicated SQL + query). """ - # We will have already filtered by the topological tokens, so we don't - # bother checking topological token bounds again. - if from_token and from_token.topological: - from_token = None + event_historical_tuple = ( + topological_ordering, + stream_ordering, + ) - if to_token and to_token.topological: - to_token = None + if lower_token: + if lower_token.topological: + # If these are historical tokens we compare the `(topological, stream)` + # tuples. + if event_historical_tuple <= lower_token.as_historical_tuple(): + return False - lower_bound = None - if direction == "f" and from_token: - lower_bound = from_token.get_stream_pos_for_instance(instance_name) - elif direction == "b" and to_token: - lower_bound = to_token.get_stream_pos_for_instance(instance_name) + else: + # If these are live tokens we compare the stream ordering against the + # writers stream position. + if stream_ordering <= lower_token.get_stream_pos_for_instance( + instance_name + ): + return False - if lower_bound and stream_ordering <= lower_bound: - return False - - upper_bound = None - if direction == "b" and from_token: - upper_bound = from_token.get_stream_pos_for_instance(instance_name) - elif direction == "f" and to_token: - upper_bound = to_token.get_stream_pos_for_instance(instance_name) - - if upper_bound and upper_bound < stream_ordering: - return False + if upper_token: + if upper_token.topological: + if upper_token.as_historical_tuple() < event_historical_tuple: + return False + else: + if upper_token.get_stream_pos_for_instance(instance_name) < stream_ordering: + return False return True @@ -482,7 +488,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta): max_to_id = to_key.get_max_stream_pos() sql = """ - SELECT event_id, instance_name, stream_ordering + SELECT event_id, instance_name, topological_ordering, stream_ordering FROM events WHERE room_id = ? @@ -496,9 +502,13 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta): rows = [ _EventDictReturn(event_id, None, stream_ordering) - for event_id, instance_name, stream_ordering in txn + for event_id, instance_name, topological_ordering, stream_ordering in txn if _filter_results( - "f", from_key, to_key, instance_name, stream_ordering + from_key, + to_key, + instance_name, + topological_ordering, + stream_ordering, ) ][:limit] return rows @@ -543,7 +553,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta): max_to_id = to_key.get_max_stream_pos() sql = """ - SELECT m.event_id, instance_name, stream_ordering + SELECT m.event_id, instance_name, topological_ordering, stream_ordering FROM events AS e, room_memberships AS m WHERE e.event_id = m.event_id AND m.user_id = ? @@ -554,9 +564,13 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta): rows = [ _EventDictReturn(event_id, None, stream_ordering) - for event_id, instance_name, stream_ordering in txn + for event_id, instance_name, topological_ordering, stream_ordering in txn if _filter_results( - "f", from_key, to_key, instance_name, stream_ordering + from_key, + to_key, + instance_name, + topological_ordering, + stream_ordering, ) ] @@ -1159,7 +1173,11 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta): _EventDictReturn(event_id, topological_ordering, stream_ordering) for event_id, instance_name, topological_ordering, stream_ordering in txn if _filter_results( - direction, from_token, to_token, instance_name, stream_ordering + lower_token=to_token if direction == "b" else from_token, + upper_token=from_token if direction == "b" else to_token, + instance_name=instance_name, + topological_ordering=topological_ordering, + stream_ordering=stream_ordering, ) ][:limit]