Make `StreamToken.room_key` be a `RoomStreamToken` instance. (#8281)

pull/8294/head
Erik Johnston 2020-09-11 12:22:55 +01:00 committed by GitHub
parent c312ee3cde
commit fe8ed1b46f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
16 changed files with 114 additions and 99 deletions

1
changelog.d/8281.misc Normal file
View File

@ -0,0 +1 @@
Change `StreamToken.room_key` to be a `RoomStreamToken` instance.

View File

@ -46,10 +46,12 @@ files =
synapse/server_notices,
synapse/spam_checker_api,
synapse/state,
synapse/storage/databases/main/events.py,
synapse/storage/databases/main/stream.py,
synapse/storage/databases/main/ui_auth.py,
synapse/storage/database.py,
synapse/storage/engines,
synapse/storage/persist_events.py,
synapse/storage/state.py,
synapse/storage/util,
synapse/streams,

View File

@ -125,8 +125,8 @@ class AdminHandler(BaseHandler):
else:
stream_ordering = room.stream_ordering
from_key = str(RoomStreamToken(0, 0))
to_key = str(RoomStreamToken(None, stream_ordering))
from_key = RoomStreamToken(0, 0)
to_key = RoomStreamToken(None, stream_ordering)
written_events = set() # Events that we've processed in this room
@ -153,7 +153,7 @@ class AdminHandler(BaseHandler):
if not events:
break
from_key = events[-1].internal_metadata.after
from_key = RoomStreamToken.parse(events[-1].internal_metadata.after)
events = await filter_events_for_client(self.storage, user_id, events)

View File

@ -29,6 +29,7 @@ from synapse.logging.opentracing import log_kv, set_tag, trace
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.types import (
RoomStreamToken,
StreamToken,
get_domain_from_id,
get_verify_key_from_cross_signing_key,
)
@ -104,18 +105,15 @@ class DeviceWorkerHandler(BaseHandler):
@trace
@measure_func("device.get_user_ids_changed")
async def get_user_ids_changed(self, user_id, from_token):
async def get_user_ids_changed(self, user_id: str, from_token: StreamToken):
"""Get list of users that have had the devices updated, or have newly
joined a room, that `user_id` may be interested in.
Args:
user_id (str)
from_token (StreamToken)
"""
set_tag("user_id", user_id)
set_tag("from_token", from_token)
now_room_key = await self.store.get_room_events_max_id()
now_room_id = self.store.get_room_max_stream_ordering()
now_room_key = RoomStreamToken(None, now_room_id)
room_ids = await self.store.get_rooms_for_user(user_id)
@ -142,7 +140,7 @@ class DeviceWorkerHandler(BaseHandler):
)
rooms_changed.update(event.room_id for event in member_events)
stream_ordering = RoomStreamToken.parse_stream_token(from_token.room_key).stream
stream_ordering = from_token.room_key.stream
possibly_changed = set(changed)
possibly_left = set()

View File

@ -25,7 +25,7 @@ from synapse.handlers.presence import format_user_presence_state
from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.storage.roommember import RoomsForUser
from synapse.streams.config import PaginationConfig
from synapse.types import JsonDict, Requester, StreamToken, UserID
from synapse.types import JsonDict, Requester, RoomStreamToken, StreamToken, UserID
from synapse.util import unwrapFirstError
from synapse.util.async_helpers import concurrently_execute
from synapse.util.caches.response_cache import ResponseCache
@ -167,7 +167,7 @@ class InitialSyncHandler(BaseHandler):
self.state_handler.get_current_state, event.room_id
)
elif event.membership == Membership.LEAVE:
room_end_token = "s%d" % (event.stream_ordering,)
room_end_token = RoomStreamToken(None, event.stream_ordering,)
deferred_room_state = run_in_background(
self.state_store.get_state_for_events, [event.event_id]
)

View File

@ -973,6 +973,7 @@ class EventCreationHandler:
This should only be run on the instance in charge of persisting events.
"""
assert self._is_event_writer
assert self.storage.persistence is not None
if ratelimit:
# We check if this is a room admin redacting an event so that we

View File

@ -344,7 +344,7 @@ class PaginationHandler:
# gets called.
raise Exception("limit not set")
room_token = RoomStreamToken.parse(from_token.room_key)
room_token = from_token.room_key
with await self.pagination_lock.read(room_id):
(
@ -381,7 +381,7 @@ class PaginationHandler:
if leave_token.topological < max_topo:
from_token = from_token.copy_and_replace(
"room_key", leave_token_str
"room_key", leave_token
)
await self.hs.get_handlers().federation_handler.maybe_backfill(

View File

@ -1091,20 +1091,19 @@ class RoomEventSource:
async def get_new_events(
self,
user: UserID,
from_key: str,
from_key: RoomStreamToken,
limit: int,
room_ids: List[str],
is_guest: bool,
explicit_room_id: Optional[str] = None,
) -> Tuple[List[EventBase], str]:
) -> Tuple[List[EventBase], RoomStreamToken]:
# We just ignore the key for now.
to_key = self.get_current_key()
from_token = RoomStreamToken.parse(from_key)
if from_token.topological:
if from_key.topological:
logger.warning("Stream has topological part!!!! %r", from_key)
from_key = "s%s" % (from_token.stream,)
from_key = RoomStreamToken(None, from_key.stream)
app_service = self.store.get_app_service_by_user_id(user.to_string())
if app_service:
@ -1133,14 +1132,14 @@ class RoomEventSource:
events[:] = events[:limit]
if events:
end_key = events[-1].internal_metadata.after
end_key = RoomStreamToken.parse(events[-1].internal_metadata.after)
else:
end_key = to_key
return (events, end_key)
def get_current_key(self) -> str:
return "s%d" % (self.store.get_room_max_stream_ordering(),)
def get_current_key(self) -> RoomStreamToken:
return RoomStreamToken(None, self.store.get_room_max_stream_ordering())
def get_current_key_for_room(self, room_id: str) -> Awaitable[str]:
return self.store.get_room_events_max_id(room_id)

View File

@ -378,7 +378,7 @@ class SyncHandler:
sync_config = sync_result_builder.sync_config
with Measure(self.clock, "ephemeral_by_room"):
typing_key = since_token.typing_key if since_token else "0"
typing_key = since_token.typing_key if since_token else 0
room_ids = sync_result_builder.joined_room_ids
@ -402,7 +402,7 @@ class SyncHandler:
event_copy = {k: v for (k, v) in event.items() if k != "room_id"}
ephemeral_by_room.setdefault(room_id, []).append(event_copy)
receipt_key = since_token.receipt_key if since_token else "0"
receipt_key = since_token.receipt_key if since_token else 0
receipt_source = self.event_sources.sources["receipt"]
receipts, receipt_key = await receipt_source.get_new_events(
@ -533,7 +533,7 @@ class SyncHandler:
if len(recents) > timeline_limit:
limited = True
recents = recents[-timeline_limit:]
room_key = recents[0].internal_metadata.before
room_key = RoomStreamToken.parse(recents[0].internal_metadata.before)
prev_batch_token = now_token.copy_and_replace("room_key", room_key)
@ -1322,6 +1322,7 @@ class SyncHandler:
is_guest=sync_config.is_guest,
include_offline=include_offline,
)
assert presence_key
sync_result_builder.now_token = now_token.copy_and_replace(
"presence_key", presence_key
)
@ -1484,7 +1485,7 @@ class SyncHandler:
if rooms_changed:
return True
stream_id = RoomStreamToken.parse_stream_token(since_token.room_key).stream
stream_id = since_token.room_key.stream
for room_id in sync_result_builder.joined_room_ids:
if self.store.has_room_changed_since(room_id, stream_id):
return True
@ -1750,7 +1751,7 @@ class SyncHandler:
continue
leave_token = now_token.copy_and_replace(
"room_key", "s%d" % (event.stream_ordering,)
"room_key", RoomStreamToken(None, event.stream_ordering)
)
room_entries.append(
RoomSyncResultBuilder(

View File

@ -25,6 +25,7 @@ from typing import (
Set,
Tuple,
TypeVar,
Union,
)
from prometheus_client import Counter
@ -41,7 +42,7 @@ from synapse.logging.utils import log_function
from synapse.metrics import LaterGauge
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.streams.config import PaginationConfig
from synapse.types import Collection, StreamToken, UserID
from synapse.types import Collection, RoomStreamToken, StreamToken, UserID
from synapse.util.async_helpers import ObservableDeferred, timeout_deferred
from synapse.util.metrics import Measure
from synapse.visibility import filter_events_for_client
@ -111,7 +112,9 @@ class _NotifierUserStream:
with PreserveLoggingContext():
self.notify_deferred = ObservableDeferred(defer.Deferred())
def notify(self, stream_key: str, stream_id: int, time_now_ms: int):
def notify(
self, stream_key: str, stream_id: Union[int, RoomStreamToken], time_now_ms: int,
):
"""Notify any listeners for this user of a new event from an
event source.
Args:
@ -294,7 +297,12 @@ class Notifier:
rooms.add(event.room_id)
if users or rooms:
self.on_new_event("room_key", max_room_stream_id, users=users, rooms=rooms)
self.on_new_event(
"room_key",
RoomStreamToken(None, max_room_stream_id),
users=users,
rooms=rooms,
)
self._on_updated_room_token(max_room_stream_id)
def _on_updated_room_token(self, max_room_stream_id: int):
@ -329,7 +337,7 @@ class Notifier:
def on_new_event(
self,
stream_key: str,
new_token: int,
new_token: Union[int, RoomStreamToken],
users: Collection[UserID] = [],
rooms: Collection[str] = [],
):

View File

@ -47,6 +47,9 @@ class Storage:
# interfaces.
self.main = stores.main
self.persistence = EventsPersistenceStorage(hs, stores)
self.purge_events = PurgeEventsStorage(hs, stores)
self.state = StateGroupStorage(hs, stores)
self.persistence = None
if stores.persist_events:
self.persistence = EventsPersistenceStorage(hs, stores)

View File

@ -213,7 +213,7 @@ class PersistEventsStore:
Returns:
Filtered event ids
"""
results = []
results = [] # type: List[str]
def _get_events_which_are_prevs_txn(txn, batch):
sql = """
@ -631,7 +631,9 @@ class PersistEventsStore:
)
@classmethod
def _filter_events_and_contexts_for_duplicates(cls, events_and_contexts):
def _filter_events_and_contexts_for_duplicates(
cls, events_and_contexts: List[Tuple[EventBase, EventContext]]
) -> List[Tuple[EventBase, EventContext]]:
"""Ensure that we don't have the same event twice.
Pick the earliest non-outlier if there is one, else the earliest one.
@ -641,7 +643,9 @@ class PersistEventsStore:
Returns:
list[(EventBase, EventContext)]: filtered list
"""
new_events_and_contexts = OrderedDict()
new_events_and_contexts = (
OrderedDict()
) # type: OrderedDict[str, Tuple[EventBase, EventContext]]
for event, context in events_and_contexts:
prev_event_context = new_events_and_contexts.get(event.event_id)
if prev_event_context:
@ -655,7 +659,12 @@ class PersistEventsStore:
new_events_and_contexts[event.event_id] = (event, context)
return list(new_events_and_contexts.values())
def _update_room_depths_txn(self, txn, events_and_contexts, backfilled):
def _update_room_depths_txn(
self,
txn,
events_and_contexts: List[Tuple[EventBase, EventContext]],
backfilled: bool,
):
"""Update min_depth for each room
Args:
@ -664,7 +673,7 @@ class PersistEventsStore:
we are persisting
backfilled (bool): True if the events were backfilled
"""
depth_updates = {}
depth_updates = {} # type: Dict[str, int]
for event, context in events_and_contexts:
# Remove the any existing cache entries for the event_ids
txn.call_after(self.store._invalidate_get_event_cache, event.event_id)
@ -1436,7 +1445,7 @@ class PersistEventsStore:
Forward extremities are handled when we first start persisting the events.
"""
events_by_room = {}
events_by_room = {} # type: Dict[str, List[EventBase]]
for ev in events:
events_by_room.setdefault(ev.room_id, []).append(ev)

View File

@ -310,11 +310,11 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
async def get_room_events_stream_for_rooms(
self,
room_ids: Collection[str],
from_key: str,
to_key: str,
from_key: RoomStreamToken,
to_key: RoomStreamToken,
limit: int = 0,
order: str = "DESC",
) -> Dict[str, Tuple[List[EventBase], str]]:
) -> Dict[str, Tuple[List[EventBase], RoomStreamToken]]:
"""Get new room events in stream ordering since `from_key`.
Args:
@ -333,9 +333,9 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
- list of recent events in the room
- stream ordering key for the start of the chunk of events returned.
"""
from_id = RoomStreamToken.parse_stream_token(from_key).stream
room_ids = self._events_stream_cache.get_entities_changed(room_ids, from_id)
room_ids = self._events_stream_cache.get_entities_changed(
room_ids, from_key.stream
)
if not room_ids:
return {}
@ -364,16 +364,12 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
return results
def get_rooms_that_changed(
self, room_ids: Collection[str], from_key: str
self, room_ids: Collection[str], from_key: RoomStreamToken
) -> Set[str]:
"""Given a list of rooms and a token, return rooms where there may have
been changes.
Args:
room_ids
from_key: The room_key portion of a StreamToken
"""
from_id = RoomStreamToken.parse_stream_token(from_key).stream
from_id = from_key.stream
return {
room_id
for room_id in room_ids
@ -383,11 +379,11 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
async def get_room_events_stream_for_room(
self,
room_id: str,
from_key: str,
to_key: str,
from_key: RoomStreamToken,
to_key: RoomStreamToken,
limit: int = 0,
order: str = "DESC",
) -> Tuple[List[EventBase], str]:
) -> Tuple[List[EventBase], RoomStreamToken]:
"""Get new room events in stream ordering since `from_key`.
Args:
@ -408,8 +404,8 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
if from_key == to_key:
return [], from_key
from_id = RoomStreamToken.parse_stream_token(from_key).stream
to_id = RoomStreamToken.parse_stream_token(to_key).stream
from_id = from_key.stream
to_id = to_key.stream
has_changed = self._events_stream_cache.has_entity_changed(room_id, from_id)
@ -441,7 +437,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
ret.reverse()
if rows:
key = "s%d" % min(r.stream_ordering for r in rows)
key = RoomStreamToken(None, min(r.stream_ordering for r in rows))
else:
# Assume we didn't get anything because there was nothing to
# get.
@ -450,10 +446,10 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
return ret, key
async def get_membership_changes_for_user(
self, user_id: str, from_key: str, to_key: str
self, user_id: str, from_key: RoomStreamToken, to_key: RoomStreamToken
) -> List[EventBase]:
from_id = RoomStreamToken.parse_stream_token(from_key).stream
to_id = RoomStreamToken.parse_stream_token(to_key).stream
from_id = from_key.stream
to_id = to_key.stream
if from_key == to_key:
return []
@ -491,8 +487,8 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
return ret
async def get_recent_events_for_room(
self, room_id: str, limit: int, end_token: str
) -> Tuple[List[EventBase], str]:
self, room_id: str, limit: int, end_token: RoomStreamToken
) -> Tuple[List[EventBase], RoomStreamToken]:
"""Get the most recent events in the room in topological ordering.
Args:
@ -518,8 +514,8 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
return (events, token)
async def get_recent_event_ids_for_room(
self, room_id: str, limit: int, end_token: str
) -> Tuple[List[_EventDictReturn], str]:
self, room_id: str, limit: int, end_token: RoomStreamToken
) -> Tuple[List[_EventDictReturn], RoomStreamToken]:
"""Get the most recent events in the room in topological ordering.
Args:
@ -535,13 +531,11 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
if limit == 0:
return [], end_token
parsed_end_token = RoomStreamToken.parse(end_token)
rows, token = await self.db_pool.runInteraction(
"get_recent_event_ids_for_room",
self._paginate_room_events_txn,
room_id,
from_token=parsed_end_token,
from_token=end_token,
limit=limit,
)
@ -619,17 +613,17 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
allow_none=allow_none,
)
async def get_stream_token_for_event(self, event_id: str) -> str:
async def get_stream_token_for_event(self, event_id: str) -> RoomStreamToken:
"""The stream token for an event
Args:
event_id: The id of the event to look up a stream token for.
Raises:
StoreError if the event wasn't in the database.
Returns:
A "s%d" stream token.
A stream token.
"""
stream_id = await self.get_stream_id_for_event(event_id)
return "s%d" % (stream_id,)
return RoomStreamToken(None, stream_id)
async def get_topological_token_for_event(self, event_id: str) -> str:
"""The stream token for an event
@ -954,7 +948,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
direction: str = "b",
limit: int = -1,
event_filter: Optional[Filter] = None,
) -> Tuple[List[_EventDictReturn], str]:
) -> Tuple[List[_EventDictReturn], RoomStreamToken]:
"""Returns list of events before or after a given token.
Args:
@ -1054,17 +1048,17 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
# TODO (erikj): We should work out what to do here instead.
next_token = to_token if to_token else from_token
return rows, str(next_token)
return rows, next_token
async def paginate_room_events(
self,
room_id: str,
from_key: str,
to_key: Optional[str] = None,
from_key: RoomStreamToken,
to_key: Optional[RoomStreamToken] = None,
direction: str = "b",
limit: int = -1,
event_filter: Optional[Filter] = None,
) -> Tuple[List[EventBase], str]:
) -> Tuple[List[EventBase], RoomStreamToken]:
"""Returns list of events before or after a given token.
Args:
@ -1083,17 +1077,12 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
and `to_key`).
"""
parsed_from_key = RoomStreamToken.parse(from_key)
parsed_to_key = None
if to_key:
parsed_to_key = RoomStreamToken.parse(to_key)
rows, token = await self.db_pool.runInteraction(
"paginate_room_events",
self._paginate_room_events_txn,
room_id,
parsed_from_key,
parsed_to_key,
from_key,
to_key,
direction,
limit,
event_filter,

View File

@ -18,7 +18,7 @@
import itertools
import logging
from collections import deque, namedtuple
from typing import Iterable, List, Optional, Set, Tuple
from typing import Dict, Iterable, List, Optional, Set, Tuple
from prometheus_client import Counter, Histogram
@ -31,7 +31,7 @@ from synapse.logging.context import PreserveLoggingContext, make_deferred_yielda
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage.databases import Databases
from synapse.storage.databases.main.events import DeltaState
from synapse.types import StateMap
from synapse.types import Collection, StateMap
from synapse.util.async_helpers import ObservableDeferred
from synapse.util.metrics import Measure
@ -185,6 +185,8 @@ class EventsPersistenceStorage:
# store for now.
self.main_store = stores.main
self.state_store = stores.state
assert stores.persist_events
self.persist_events_store = stores.persist_events
self._clock = hs.get_clock()
@ -208,7 +210,7 @@ class EventsPersistenceStorage:
Returns:
the stream ordering of the latest persisted event
"""
partitioned = {}
partitioned = {} # type: Dict[str, List[Tuple[EventBase, EventContext]]]
for event, ctx in events_and_contexts:
partitioned.setdefault(event.room_id, []).append((event, ctx))
@ -305,7 +307,9 @@ class EventsPersistenceStorage:
# Work out the new "current state" for each room.
# We do this by working out what the new extremities are and then
# calculating the state from that.
events_by_room = {}
events_by_room = (
{}
) # type: Dict[str, List[Tuple[EventBase, EventContext]]]
for event, context in chunk:
events_by_room.setdefault(event.room_id, []).append(
(event, context)
@ -436,7 +440,7 @@ class EventsPersistenceStorage:
self,
room_id: str,
event_contexts: List[Tuple[EventBase, EventContext]],
latest_event_ids: List[str],
latest_event_ids: Collection[str],
):
"""Calculates the new forward extremities for a room given events to
persist.
@ -470,7 +474,7 @@ class EventsPersistenceStorage:
# Remove any events which are prev_events of any existing events.
existing_prevs = await self.persist_events_store._get_events_which_are_prevs(
result
)
) # type: Collection[str]
result.difference_update(existing_prevs)
# Finally handle the case where the new events have soft-failed prev

View File

@ -425,7 +425,9 @@ class RoomStreamToken:
@attr.s(slots=True, frozen=True)
class StreamToken:
room_key = attr.ib(type=str)
room_key = attr.ib(
type=RoomStreamToken, validator=attr.validators.instance_of(RoomStreamToken)
)
presence_key = attr.ib(type=int)
typing_key = attr.ib(type=int)
receipt_key = attr.ib(type=int)
@ -445,21 +447,16 @@ class StreamToken:
while len(keys) < len(attr.fields(cls)):
# i.e. old token from before receipt_key
keys.append("0")
return cls(keys[0], *(int(k) for k in keys[1:]))
return cls(RoomStreamToken.parse(keys[0]), *(int(k) for k in keys[1:]))
except Exception:
raise SynapseError(400, "Invalid Token")
def to_string(self):
return self._SEPARATOR.join([str(k) for k in attr.astuple(self)])
return self._SEPARATOR.join([str(k) for k in attr.astuple(self, recurse=False)])
@property
def room_stream_id(self):
# TODO(markjh): Awful hack to work around hacks in the presence tests
# which assume that the keys are integers.
if type(self.room_key) is int:
return self.room_key
else:
return int(self.room_key[1:].split("-")[-1])
return self.room_key.stream
def is_after(self, other):
"""Does this token contain events that the other doesn't?"""
@ -475,7 +472,7 @@ class StreamToken:
or (int(other.groups_key) < int(self.groups_key))
)
def copy_and_advance(self, key, new_value):
def copy_and_advance(self, key, new_value) -> "StreamToken":
"""Advance the given key in the token to a new value if and only if the
new value is after the old value.
"""
@ -491,7 +488,7 @@ class StreamToken:
else:
return self
def copy_and_replace(self, key, new_value):
def copy_and_replace(self, key, new_value) -> "StreamToken":
return attr.evolve(self, **{key: new_value})

View File

@ -71,7 +71,10 @@ async def inject_event(
"""
event, context = await create_event(hs, room_version, prev_event_ids, **kwargs)
await hs.get_storage().persistence.persist_event(event, context)
persistence = hs.get_storage().persistence
assert persistence is not None
await persistence.persist_event(event, context)
return event