Fix up _filter_results

pull/8439/head
Erik Johnston 2020-10-06 11:44:10 +01:00
parent e91f0e9d13
commit 604e33fbb3
1 changed files with 54 additions and 36 deletions

View File

@ -211,44 +211,50 @@ def _make_generic_sql_bound(
def _filter_results( def _filter_results(
direction: str, lower_token: Optional[RoomStreamToken],
from_token: Optional[RoomStreamToken], upper_token: Optional[RoomStreamToken],
to_token: Optional[RoomStreamToken],
instance_name: str, instance_name: str,
topological_ordering: int,
stream_ordering: int, stream_ordering: int,
) -> bool: ) -> bool:
"""Filter results from fetching events in the DB against the given tokens. """Returns True if the event persisted by the given instance at the given
topological/stream_ordering falls between the two tokens (taking a None
token to mean unbounded).
This is necessary to handle the case where the tokens include position Used to filter results from fetching events in the DB against the given
maps, which we handle by fetching more than necessary from the DB and then tokens. This is necessary to handle the case where the tokens include
filtering (rather than attempting to construct a complicated SQL query). position maps, which we handle by fetching more than necessary from the DB
and then filtering (rather than attempting to construct a complicated SQL
query).
""" """
# We will have already filtered by the topological tokens, so we don't event_historical_tuple = (
# bother checking topological token bounds again. topological_ordering,
if from_token and from_token.topological: stream_ordering,
from_token = None )
if to_token and to_token.topological: if lower_token:
to_token = None if lower_token.topological:
# If these are historical tokens we compare the `(topological, stream)`
# tuples.
if event_historical_tuple <= lower_token.as_historical_tuple():
return False
lower_bound = None else:
if direction == "f" and from_token: # If these are live tokens we compare the stream ordering against the
lower_bound = from_token.get_stream_pos_for_instance(instance_name) # writers stream position.
elif direction == "b" and to_token: if stream_ordering <= lower_token.get_stream_pos_for_instance(
lower_bound = to_token.get_stream_pos_for_instance(instance_name) instance_name
):
return False
if lower_bound and stream_ordering <= lower_bound: if upper_token:
return False if upper_token.topological:
if upper_token.as_historical_tuple() < event_historical_tuple:
upper_bound = None return False
if direction == "b" and from_token: else:
upper_bound = from_token.get_stream_pos_for_instance(instance_name) if upper_token.get_stream_pos_for_instance(instance_name) < stream_ordering:
elif direction == "f" and to_token: return False
upper_bound = to_token.get_stream_pos_for_instance(instance_name)
if upper_bound and upper_bound < stream_ordering:
return False
return True return True
@ -482,7 +488,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
max_to_id = to_key.get_max_stream_pos() max_to_id = to_key.get_max_stream_pos()
sql = """ sql = """
SELECT event_id, instance_name, stream_ordering SELECT event_id, instance_name, topological_ordering, stream_ordering
FROM events FROM events
WHERE WHERE
room_id = ? room_id = ?
@ -496,9 +502,13 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
rows = [ rows = [
_EventDictReturn(event_id, None, stream_ordering) _EventDictReturn(event_id, None, stream_ordering)
for event_id, instance_name, stream_ordering in txn for event_id, instance_name, topological_ordering, stream_ordering in txn
if _filter_results( if _filter_results(
"f", from_key, to_key, instance_name, stream_ordering from_key,
to_key,
instance_name,
topological_ordering,
stream_ordering,
) )
][:limit] ][:limit]
return rows return rows
@ -543,7 +553,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
max_to_id = to_key.get_max_stream_pos() max_to_id = to_key.get_max_stream_pos()
sql = """ sql = """
SELECT m.event_id, instance_name, stream_ordering SELECT m.event_id, instance_name, topological_ordering, stream_ordering
FROM events AS e, room_memberships AS m FROM events AS e, room_memberships AS m
WHERE e.event_id = m.event_id WHERE e.event_id = m.event_id
AND m.user_id = ? AND m.user_id = ?
@ -554,9 +564,13 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
rows = [ rows = [
_EventDictReturn(event_id, None, stream_ordering) _EventDictReturn(event_id, None, stream_ordering)
for event_id, instance_name, stream_ordering in txn for event_id, instance_name, topological_ordering, stream_ordering in txn
if _filter_results( if _filter_results(
"f", from_key, to_key, instance_name, stream_ordering from_key,
to_key,
instance_name,
topological_ordering,
stream_ordering,
) )
] ]
@ -1159,7 +1173,11 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
_EventDictReturn(event_id, topological_ordering, stream_ordering) _EventDictReturn(event_id, topological_ordering, stream_ordering)
for event_id, instance_name, topological_ordering, stream_ordering in txn for event_id, instance_name, topological_ordering, stream_ordering in txn
if _filter_results( if _filter_results(
direction, from_token, to_token, instance_name, stream_ordering lower_token=to_token if direction == "b" else from_token,
upper_token=from_token if direction == "b" else to_token,
instance_name=instance_name,
topological_ordering=topological_ordering,
stream_ordering=stream_ordering,
) )
][:limit] ][:limit]