Wire up new token when fetching events streams
parent
b2172da9bb
commit
d7da8ca8a8
|
@ -53,6 +53,7 @@ from synapse.storage.database import (
|
|||
)
|
||||
from synapse.storage.databases.main.events_worker import EventsWorkerStore
|
||||
from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine
|
||||
from synapse.storage.util.id_generators import MultiWriterIdGenerator
|
||||
from synapse.types import Collection, PersistedEventPosition, RoomStreamToken
|
||||
from synapse.util.caches.descriptors import cached
|
||||
from synapse.util.caches.stream_change_cache import StreamChangeCache
|
||||
|
@ -209,6 +210,49 @@ def _make_generic_sql_bound(
|
|||
)
|
||||
|
||||
|
||||
def _filter_results(
|
||||
direction: str,
|
||||
from_token: Optional[RoomStreamToken],
|
||||
to_token: Optional[RoomStreamToken],
|
||||
instance_name: str,
|
||||
stream_ordering: int,
|
||||
) -> bool:
|
||||
"""Filter results from fetching events in the DB against the given tokens.
|
||||
|
||||
This is necessary to handle the case where the tokens include positions
|
||||
maps, which we handle by fetching more than necessary from the DB and then
|
||||
filtering (rather than attempting to construct a complicated SQL query).
|
||||
"""
|
||||
|
||||
# We will have already filtered by the topological tokens, so we don't
|
||||
# bother checking topological token bounds again.
|
||||
if from_token and from_token.topological:
|
||||
from_token = None
|
||||
|
||||
if to_token and to_token.topological:
|
||||
to_token = None
|
||||
|
||||
lower_bound = None
|
||||
if direction == "f" and from_token:
|
||||
lower_bound = from_token.get_stream_pos_for_instance(instance_name)
|
||||
elif direction == "b" and to_token:
|
||||
lower_bound = to_token.get_stream_pos_for_instance(instance_name)
|
||||
|
||||
if lower_bound and stream_ordering <= lower_bound:
|
||||
return False
|
||||
|
||||
upper_bound = None
|
||||
if direction == "b" and from_token:
|
||||
upper_bound = from_token.get_stream_pos_for_instance(instance_name)
|
||||
elif direction == "f" and to_token:
|
||||
upper_bound = to_token.get_stream_pos_for_instance(instance_name)
|
||||
|
||||
if upper_bound and upper_bound < stream_ordering:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def filter_to_clause(event_filter: Optional[Filter]) -> Tuple[str, List[str]]:
|
||||
# NB: This may create SQL clauses that don't optimise well (and we don't
|
||||
# have indices on all possible clauses). E.g. it may create
|
||||
|
@ -306,7 +350,26 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
|
|||
raise NotImplementedError()
|
||||
|
||||
def get_room_max_token(self) -> RoomStreamToken:
    """Build a token for the current position of the stream ID generator.

    The token's stream position is the generator's current token. When the
    generator is a `MultiWriterIdGenerator`, the token additionally carries
    the positions of any instances that are ahead of that position.

    Returns:
        A non-topological `RoomStreamToken`.
    """
    min_pos = self._stream_id_gen.get_current_token()

    positions = {}
    if isinstance(self._stream_id_gen, MultiWriterIdGenerator):
        # The `min_pos` is the minimum position that we know all instances
        # have finished persisting to, so we only care about instances whose
        # positions are ahead of that. (Instance positions can be behind the
        # min position as there are times we can work out that the minimum
        # position is ahead of the naive minimum across all current
        # positions. See MultiWriterIdGenerator for details)
        positions = {
            i: p
            for i, p in self._stream_id_gen.get_positions().items()
            if p > min_pos
        }

    return RoomStreamToken(None, min_pos, positions)
|
||||
|
||||
async def get_room_events_stream_for_rooms(
|
||||
self,
|
||||
|
@ -405,25 +468,39 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
|
|||
if from_key == to_key:
|
||||
return [], from_key
|
||||
|
||||
from_id = from_key.stream
|
||||
to_id = to_key.stream
|
||||
|
||||
has_changed = self._events_stream_cache.has_entity_changed(room_id, from_id)
|
||||
has_changed = self._events_stream_cache.has_entity_changed(
|
||||
room_id, from_key.stream
|
||||
)
|
||||
|
||||
if not has_changed:
|
||||
return [], from_key
|
||||
|
||||
def f(txn):
|
||||
sql = (
|
||||
"SELECT event_id, stream_ordering FROM events WHERE"
|
||||
" room_id = ?"
|
||||
" AND not outlier"
|
||||
" AND stream_ordering > ? AND stream_ordering <= ?"
|
||||
" ORDER BY stream_ordering %s LIMIT ?"
|
||||
) % (order,)
|
||||
txn.execute(sql, (room_id, from_id, to_id, limit))
|
||||
# To handle tokens with a non-empty instance_map we fetch more
|
||||
# results than necessary and then filter down
|
||||
min_from_id = from_key.stream
|
||||
max_to_id = to_key.get_max_stream_pos()
|
||||
|
||||
rows = [_EventDictReturn(row[0], None, row[1]) for row in txn]
|
||||
sql = """
|
||||
SELECT event_id, instance_name, stream_ordering
|
||||
FROM events
|
||||
WHERE
|
||||
room_id = ?
|
||||
AND not outlier
|
||||
AND stream_ordering > ? AND stream_ordering <= ?
|
||||
ORDER BY stream_ordering %s LIMIT ?
|
||||
""" % (
|
||||
order,
|
||||
)
|
||||
txn.execute(sql, (room_id, min_from_id, max_to_id, 2 * limit))
|
||||
|
||||
rows = [
|
||||
_EventDictReturn(event_id, None, stream_ordering)
|
||||
for event_id, instance_name, stream_ordering in txn
|
||||
if _filter_results(
|
||||
"f", from_key, to_key, instance_name, stream_ordering
|
||||
)
|
||||
][:limit]
|
||||
return rows
|
||||
|
||||
rows = await self.db_pool.runInteraction("get_room_events_stream_for_room", f)
|
||||
|
@ -432,7 +509,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
|
|||
[r.event_id for r in rows], get_prev_content=True
|
||||
)
|
||||
|
||||
self._set_before_and_after(ret, rows, topo_order=from_id is None)
|
||||
self._set_before_and_after(ret, rows, topo_order=from_key.stream is None)
|
||||
|
||||
if order.lower() == "desc":
|
||||
ret.reverse()
|
||||
|
@ -449,31 +526,39 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
|
|||
async def get_membership_changes_for_user(
|
||||
self, user_id: str, from_key: RoomStreamToken, to_key: RoomStreamToken
|
||||
) -> List[EventBase]:
|
||||
from_id = from_key.stream
|
||||
to_id = to_key.stream
|
||||
|
||||
if from_key == to_key:
|
||||
return []
|
||||
|
||||
if from_id:
|
||||
if from_key:
|
||||
has_changed = self._membership_stream_cache.has_entity_changed(
|
||||
user_id, int(from_id)
|
||||
user_id, int(from_key.stream)
|
||||
)
|
||||
if not has_changed:
|
||||
return []
|
||||
|
||||
def f(txn):
|
||||
sql = (
|
||||
"SELECT m.event_id, stream_ordering FROM events AS e,"
|
||||
" room_memberships AS m"
|
||||
" WHERE e.event_id = m.event_id"
|
||||
" AND m.user_id = ?"
|
||||
" AND e.stream_ordering > ? AND e.stream_ordering <= ?"
|
||||
" ORDER BY e.stream_ordering ASC"
|
||||
)
|
||||
txn.execute(sql, (user_id, from_id, to_id))
|
||||
# To handle tokens with a non-empty instance_map we fetch more
|
||||
# results than necessary and then filter down
|
||||
min_from_id = from_key.stream
|
||||
max_to_id = to_key.get_max_stream_pos()
|
||||
|
||||
rows = [_EventDictReturn(row[0], None, row[1]) for row in txn]
|
||||
sql = """
|
||||
SELECT m.event_id, instance_name, stream_ordering
|
||||
FROM events AS e, room_memberships AS m
|
||||
WHERE e.event_id = m.event_id
|
||||
AND m.user_id = ?
|
||||
AND e.stream_ordering > ? AND e.stream_ordering <= ?
|
||||
ORDER BY e.stream_ordering ASC
|
||||
"""
|
||||
txn.execute(sql, (user_id, min_from_id, max_to_id,))
|
||||
|
||||
rows = [
|
||||
_EventDictReturn(event_id, None, stream_ordering)
|
||||
for event_id, instance_name, stream_ordering in txn
|
||||
if _filter_results(
|
||||
"f", from_key, to_key, instance_name, stream_ordering
|
||||
)
|
||||
]
|
||||
|
||||
return rows
|
||||
|
||||
|
@ -980,11 +1065,44 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
|
|||
else:
|
||||
order = "ASC"
|
||||
|
||||
# The bounds for the stream tokens are complicated by the fact
|
||||
# that we need to handle the instance_map part of the tokens. We do this
|
||||
# by fetching all events between the min stream token and the maximum
|
||||
# stream token (as returned by `RoomStreamToken.get_max_stream_pos`) and
|
||||
# then filtering the results.
|
||||
if from_token.topological is not None:
|
||||
from_bound = from_token.as_tuple()
|
||||
elif direction == "b":
|
||||
from_bound = (
|
||||
None,
|
||||
from_token.get_max_stream_pos(),
|
||||
)
|
||||
else:
|
||||
from_bound = (
|
||||
None,
|
||||
from_token.stream,
|
||||
)
|
||||
|
||||
to_bound = None
|
||||
if to_token:
|
||||
if to_token.topological is not None:
|
||||
to_bound = to_token.as_tuple()
|
||||
elif direction == "b":
|
||||
to_bound = (
|
||||
None,
|
||||
to_token.stream,
|
||||
)
|
||||
else:
|
||||
to_bound = (
|
||||
None,
|
||||
to_token.get_max_stream_pos(),
|
||||
)
|
||||
|
||||
bounds = generate_pagination_where_clause(
|
||||
direction=direction,
|
||||
column_names=("topological_ordering", "stream_ordering"),
|
||||
from_token=from_token.as_tuple(),
|
||||
to_token=to_token.as_tuple() if to_token else None,
|
||||
from_token=from_bound,
|
||||
to_token=to_bound,
|
||||
engine=self.database_engine,
|
||||
)
|
||||
|
||||
|
@ -994,7 +1112,8 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
|
|||
bounds += " AND " + filter_clause
|
||||
args.extend(filter_args)
|
||||
|
||||
args.append(int(limit))
|
||||
# We fetch more events as we'll filter the result set
|
||||
args.append(int(limit) * 2)
|
||||
|
||||
select_keywords = "SELECT"
|
||||
join_clause = ""
|
||||
|
@ -1016,7 +1135,9 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
|
|||
select_keywords += "DISTINCT"
|
||||
|
||||
sql = """
|
||||
%(select_keywords)s event_id, topological_ordering, stream_ordering
|
||||
%(select_keywords)s
|
||||
event_id, instance_name,
|
||||
topological_ordering, stream_ordering
|
||||
FROM events
|
||||
%(join_clause)s
|
||||
WHERE outlier = ? AND room_id = ? AND %(bounds)s
|
||||
|
@ -1031,7 +1152,14 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
|
|||
|
||||
txn.execute(sql, args)
|
||||
|
||||
rows = [_EventDictReturn(row[0], row[1], row[2]) for row in txn]
|
||||
# Filter the result set.
|
||||
rows = [
|
||||
_EventDictReturn(event_id, topological_ordering, stream_ordering)
|
||||
for event_id, instance_name, topological_ordering, stream_ordering in txn
|
||||
if _filter_results(
|
||||
direction, from_token, to_token, instance_name, stream_ordering
|
||||
)
|
||||
][:limit]
|
||||
|
||||
if rows:
|
||||
topo = rows[-1].topological_ordering
|
||||
|
|
|
@ -488,6 +488,21 @@ class RoomStreamToken:
|
|||
def as_tuple(self) -> Tuple[Optional[int], int]:
    """Return the token as a `(topological, stream)` pair."""
    topological_part = self.topological
    stream_part = self.stream
    return topological_part, stream_part
|
||||
|
||||
def get_stream_pos_for_instance(self, instance_name: str) -> int:
    """Get the stream position for the given instance.

    Args:
        instance_name: The instance to look up in `self.instance_map`.

    Returns:
        The position recorded for the instance in `self.instance_map`, or
        `self.stream` if the instance has no entry there.
    """
    fallback_pos = self.stream
    per_instance = self.instance_map
    return per_instance.get(instance_name, fallback_pos)
|
||||
|
||||
def get_max_stream_pos(self) -> int:
    """Get the maximum stream position referenced in this token.

    The corresponding "min" position is, by definition, just `self.stream`.

    This is used to handle tokens that have a non-empty `instance_map`, and
    so reference stream positions after the `self.stream` position.

    Returns:
        The largest value in `self.instance_map`, or `self.stream` when the
        map is empty.
    """
    if self.instance_map:
        return max(self.instance_map.values())

    # No per-instance positions, so the token only references `self.stream`.
    return self.stream
|
||||
|
||||
async def to_string(self, store: "DataStore") -> str:
|
||||
if self.topological is not None:
|
||||
return "t%d-%d" % (self.topological, self.stream)
|
||||
|
|
Loading…
Reference in New Issue