Fix up _filter_results

pull/8439/head
Erik Johnston 2020-10-06 11:44:10 +01:00
parent e91f0e9d13
commit 604e33fbb3
1 changed files with 54 additions and 36 deletions

View File

@ -211,44 +211,50 @@ def _make_generic_sql_bound(
def _filter_results(
direction: str,
from_token: Optional[RoomStreamToken],
to_token: Optional[RoomStreamToken],
lower_token: Optional[RoomStreamToken],
upper_token: Optional[RoomStreamToken],
instance_name: str,
topological_ordering: int,
stream_ordering: int,
) -> bool:
"""Filter results from fetching events in the DB against the given tokens.
"""Returns True if the event persisted by the given instance at the given
topological/stream_ordering falls between the two tokens (taking a None
token to mean unbounded).
This is necessary to handle the case where the tokens include position
maps, which we handle by fetching more than necessary from the DB and then
filtering (rather than attempting to construct a complicated SQL query).
Used to filter results from fetching events in the DB against the given
tokens. This is necessary to handle the case where the tokens include
position maps, which we handle by fetching more than necessary from the DB
and then filtering (rather than attempting to construct a complicated SQL
query).
"""
# We will have already filtered by the topological tokens, so we don't
# bother checking topological token bounds again.
if from_token and from_token.topological:
from_token = None
event_historical_tuple = (
topological_ordering,
stream_ordering,
)
if to_token and to_token.topological:
to_token = None
if lower_token:
if lower_token.topological:
# If these are historical tokens we compare the `(topological, stream)`
# tuples.
if event_historical_tuple <= lower_token.as_historical_tuple():
return False
lower_bound = None
if direction == "f" and from_token:
lower_bound = from_token.get_stream_pos_for_instance(instance_name)
elif direction == "b" and to_token:
lower_bound = to_token.get_stream_pos_for_instance(instance_name)
else:
# If these are live tokens we compare the stream ordering against the
# writers stream position.
if stream_ordering <= lower_token.get_stream_pos_for_instance(
instance_name
):
return False
if lower_bound and stream_ordering <= lower_bound:
return False
upper_bound = None
if direction == "b" and from_token:
upper_bound = from_token.get_stream_pos_for_instance(instance_name)
elif direction == "f" and to_token:
upper_bound = to_token.get_stream_pos_for_instance(instance_name)
if upper_bound and upper_bound < stream_ordering:
return False
if upper_token:
if upper_token.topological:
if upper_token.as_historical_tuple() < event_historical_tuple:
return False
else:
if upper_token.get_stream_pos_for_instance(instance_name) < stream_ordering:
return False
return True
@ -482,7 +488,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
max_to_id = to_key.get_max_stream_pos()
sql = """
SELECT event_id, instance_name, stream_ordering
SELECT event_id, instance_name, topological_ordering, stream_ordering
FROM events
WHERE
room_id = ?
@ -496,9 +502,13 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
rows = [
_EventDictReturn(event_id, None, stream_ordering)
for event_id, instance_name, stream_ordering in txn
for event_id, instance_name, topological_ordering, stream_ordering in txn
if _filter_results(
"f", from_key, to_key, instance_name, stream_ordering
from_key,
to_key,
instance_name,
topological_ordering,
stream_ordering,
)
][:limit]
return rows
@ -543,7 +553,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
max_to_id = to_key.get_max_stream_pos()
sql = """
SELECT m.event_id, instance_name, stream_ordering
SELECT m.event_id, instance_name, topological_ordering, stream_ordering
FROM events AS e, room_memberships AS m
WHERE e.event_id = m.event_id
AND m.user_id = ?
@ -554,9 +564,13 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
rows = [
_EventDictReturn(event_id, None, stream_ordering)
for event_id, instance_name, stream_ordering in txn
for event_id, instance_name, topological_ordering, stream_ordering in txn
if _filter_results(
"f", from_key, to_key, instance_name, stream_ordering
from_key,
to_key,
instance_name,
topological_ordering,
stream_ordering,
)
]
@ -1159,7 +1173,11 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
_EventDictReturn(event_id, topological_ordering, stream_ordering)
for event_id, instance_name, topological_ordering, stream_ordering in txn
if _filter_results(
direction, from_token, to_token, instance_name, stream_ordering
lower_token=to_token if direction == "b" else from_token,
upper_token=from_token if direction == "b" else to_token,
instance_name=instance_name,
topological_ordering=topological_ordering,
stream_ordering=stream_ordering,
)
][:limit]