Faster room joins: fix race in recalculation of current room state (#13151)
Bounce recalculation of current state to the correct event persister and move recalculation of current state into the event persistence queue, to avoid concurrent updates to a room's current state. Also give recalculation of a room's current state a real stream ordering. Signed-off-by: Sean Quah <seanq@matrix.org> pull/13211/head
parent
2b5ab8e367
commit
1391a76cd2
|
@ -0,0 +1 @@
|
|||
Faster room joins: fix race in recalculation of current room state.
|
|
@ -1559,14 +1559,9 @@ class FederationHandler:
|
|||
# all the events are updated, so we can update current state and
|
||||
# clear the lazy-loading flag.
|
||||
logger.info("Updating current state for %s", room_id)
|
||||
# TODO(faster_joins): support workers
|
||||
# TODO(faster_joins): notify workers in notify_room_un_partial_stated
|
||||
# https://github.com/matrix-org/synapse/issues/12994
|
||||
assert (
|
||||
self._storage_controllers.persistence is not None
|
||||
), "worker-mode deployments not currently supported here"
|
||||
await self._storage_controllers.persistence.update_current_state(
|
||||
room_id
|
||||
)
|
||||
await self.state_handler.update_current_state(room_id)
|
||||
|
||||
logger.info("Clearing partial-state flag for %s", room_id)
|
||||
success = await self.store.clear_partial_state_room(room_id)
|
||||
|
|
|
@ -25,6 +25,7 @@ from synapse.replication.http import (
|
|||
push,
|
||||
register,
|
||||
send_event,
|
||||
state,
|
||||
streams,
|
||||
)
|
||||
|
||||
|
@ -48,6 +49,7 @@ class ReplicationRestResource(JsonResource):
|
|||
streams.register_servlets(hs, self)
|
||||
account_data.register_servlets(hs, self)
|
||||
push.register_servlets(hs, self)
|
||||
state.register_servlets(hs, self)
|
||||
|
||||
# The following can't currently be instantiated on workers.
|
||||
if hs.config.worker.worker_app is None:
|
||||
|
|
|
@ -0,0 +1,75 @@
|
|||
# Copyright 2022 The Matrix.org Foundation C.I.C.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Tuple
|
||||
|
||||
from twisted.web.server import Request
|
||||
|
||||
from synapse.api.errors import SynapseError
|
||||
from synapse.http.server import HttpServer
|
||||
from synapse.replication.http._base import ReplicationEndpoint
|
||||
from synapse.types import JsonDict
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from synapse.server import HomeServer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ReplicationUpdateCurrentStateRestServlet(ReplicationEndpoint):
|
||||
"""Recalculates the current state for a room, and persists it.
|
||||
|
||||
The API looks like:
|
||||
|
||||
POST /_synapse/replication/update_current_state/:room_id
|
||||
|
||||
{}
|
||||
|
||||
200 OK
|
||||
|
||||
{}
|
||||
"""
|
||||
|
||||
NAME = "update_current_state"
|
||||
PATH_ARGS = ("room_id",)
|
||||
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
super().__init__(hs)
|
||||
|
||||
self._state_handler = hs.get_state_handler()
|
||||
self._events_shard_config = hs.config.worker.events_shard_config
|
||||
self._instance_name = hs.get_instance_name()
|
||||
|
||||
@staticmethod
|
||||
async def _serialize_payload(room_id: str) -> JsonDict: # type: ignore[override]
|
||||
return {}
|
||||
|
||||
async def _handle_request( # type: ignore[override]
|
||||
self, request: Request, room_id: str
|
||||
) -> Tuple[int, JsonDict]:
|
||||
writer_instance = self._events_shard_config.get_instance(room_id)
|
||||
if writer_instance != self._instance_name:
|
||||
raise SynapseError(
|
||||
400, "/update_current_state request was routed to the wrong worker"
|
||||
)
|
||||
|
||||
await self._state_handler.update_current_state(room_id)
|
||||
|
||||
return 200, {}
|
||||
|
||||
|
||||
def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
|
||||
if hs.get_instance_name() in hs.config.worker.writers.events:
|
||||
ReplicationUpdateCurrentStateRestServlet(hs).register(http_server)
|
|
@ -43,6 +43,7 @@ from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, StateResolutionVersio
|
|||
from synapse.events import EventBase
|
||||
from synapse.events.snapshot import EventContext
|
||||
from synapse.logging.context import ContextResourceUsage
|
||||
from synapse.replication.http.state import ReplicationUpdateCurrentStateRestServlet
|
||||
from synapse.state import v1, v2
|
||||
from synapse.storage.databases.main.events_worker import EventRedactBehaviour
|
||||
from synapse.storage.roommember import ProfileInfo
|
||||
|
@ -129,6 +130,12 @@ class StateHandler:
|
|||
self.hs = hs
|
||||
self._state_resolution_handler = hs.get_state_resolution_handler()
|
||||
self._storage_controllers = hs.get_storage_controllers()
|
||||
self._events_shard_config = hs.config.worker.events_shard_config
|
||||
self._instance_name = hs.get_instance_name()
|
||||
|
||||
self._update_current_state_client = (
|
||||
ReplicationUpdateCurrentStateRestServlet.make_client(hs)
|
||||
)
|
||||
|
||||
async def get_current_state_ids(
|
||||
self,
|
||||
|
@ -423,6 +430,24 @@ class StateHandler:
|
|||
|
||||
return {key: state_map[ev_id] for key, ev_id in new_state.items()}
|
||||
|
||||
async def update_current_state(self, room_id: str) -> None:
|
||||
"""Recalculates the current state for a room, and persists it.
|
||||
|
||||
Raises:
|
||||
SynapseError(502): if all attempts to connect to the event persister worker
|
||||
fail
|
||||
"""
|
||||
writer_instance = self._events_shard_config.get_instance(room_id)
|
||||
if writer_instance != self._instance_name:
|
||||
await self._update_current_state_client(
|
||||
instance_name=writer_instance,
|
||||
room_id=room_id,
|
||||
)
|
||||
return
|
||||
|
||||
assert self._storage_controllers.persistence is not None
|
||||
await self._storage_controllers.persistence.update_current_state(room_id)
|
||||
|
||||
|
||||
@attr.s(slots=True, auto_attribs=True)
|
||||
class _StateResMetrics:
|
||||
|
|
|
@ -22,6 +22,7 @@ from typing import (
|
|||
Any,
|
||||
Awaitable,
|
||||
Callable,
|
||||
ClassVar,
|
||||
Collection,
|
||||
Deque,
|
||||
Dict,
|
||||
|
@ -33,6 +34,7 @@ from typing import (
|
|||
Set,
|
||||
Tuple,
|
||||
TypeVar,
|
||||
Union,
|
||||
)
|
||||
|
||||
import attr
|
||||
|
@ -111,9 +113,43 @@ times_pruned_extremities = Counter(
|
|||
|
||||
|
||||
@attr.s(auto_attribs=True, slots=True)
|
||||
class _EventPersistQueueItem:
|
||||
class _PersistEventsTask:
|
||||
"""A batch of events to persist."""
|
||||
|
||||
name: ClassVar[str] = "persist_event_batch" # used for opentracing
|
||||
|
||||
events_and_contexts: List[Tuple[EventBase, EventContext]]
|
||||
backfilled: bool
|
||||
|
||||
def try_merge(self, task: "_EventPersistQueueTask") -> bool:
|
||||
"""Batches events with the same backfilled option together."""
|
||||
if (
|
||||
not isinstance(task, _PersistEventsTask)
|
||||
or self.backfilled != task.backfilled
|
||||
):
|
||||
return False
|
||||
|
||||
self.events_and_contexts.extend(task.events_and_contexts)
|
||||
return True
|
||||
|
||||
|
||||
@attr.s(auto_attribs=True, slots=True)
|
||||
class _UpdateCurrentStateTask:
|
||||
"""A room whose current state needs recalculating."""
|
||||
|
||||
name: ClassVar[str] = "update_current_state" # used for opentracing
|
||||
|
||||
def try_merge(self, task: "_EventPersistQueueTask") -> bool:
|
||||
"""Deduplicates consecutive recalculations of current state."""
|
||||
return isinstance(task, _UpdateCurrentStateTask)
|
||||
|
||||
|
||||
_EventPersistQueueTask = Union[_PersistEventsTask, _UpdateCurrentStateTask]
|
||||
|
||||
|
||||
@attr.s(auto_attribs=True, slots=True)
|
||||
class _EventPersistQueueItem:
|
||||
task: _EventPersistQueueTask
|
||||
deferred: ObservableDeferred
|
||||
|
||||
parent_opentracing_span_contexts: List = attr.ib(factory=list)
|
||||
|
@ -127,14 +163,16 @@ _PersistResult = TypeVar("_PersistResult")
|
|||
|
||||
|
||||
class _EventPeristenceQueue(Generic[_PersistResult]):
|
||||
"""Queues up events so that they can be persisted in bulk with only one
|
||||
concurrent transaction per room.
|
||||
"""Queues up tasks so that they can be processed with only one concurrent
|
||||
transaction per room.
|
||||
|
||||
Tasks can be bulk persistence of events or recalculation of a room's current state.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
per_item_callback: Callable[
|
||||
[List[Tuple[EventBase, EventContext]], bool],
|
||||
[str, _EventPersistQueueTask],
|
||||
Awaitable[_PersistResult],
|
||||
],
|
||||
):
|
||||
|
@ -150,18 +188,17 @@ class _EventPeristenceQueue(Generic[_PersistResult]):
|
|||
async def add_to_queue(
|
||||
self,
|
||||
room_id: str,
|
||||
events_and_contexts: Iterable[Tuple[EventBase, EventContext]],
|
||||
backfilled: bool,
|
||||
task: _EventPersistQueueTask,
|
||||
) -> _PersistResult:
|
||||
"""Add events to the queue, with the given persist_event options.
|
||||
"""Add a task to the queue.
|
||||
|
||||
If we are not already processing events in this room, starts off a background
|
||||
If we are not already processing tasks in this room, starts off a background
|
||||
process to do so, calling the per_item_callback for each item.
|
||||
|
||||
Args:
|
||||
room_id (str):
|
||||
events_and_contexts (list[(EventBase, EventContext)]):
|
||||
backfilled (bool):
|
||||
task (_EventPersistQueueTask): A _PersistEventsTask or
|
||||
_UpdateCurrentStateTask to process.
|
||||
|
||||
Returns:
|
||||
the result returned by the `_per_item_callback` passed to
|
||||
|
@ -169,26 +206,20 @@ class _EventPeristenceQueue(Generic[_PersistResult]):
|
|||
"""
|
||||
queue = self._event_persist_queues.setdefault(room_id, deque())
|
||||
|
||||
# if the last item in the queue has the same `backfilled` setting,
|
||||
# we can just add these new events to that item.
|
||||
if queue and queue[-1].backfilled == backfilled:
|
||||
if queue and queue[-1].task.try_merge(task):
|
||||
# the new task has been merged into the last task in the queue
|
||||
end_item = queue[-1]
|
||||
else:
|
||||
# need to make a new queue item
|
||||
deferred: ObservableDeferred[_PersistResult] = ObservableDeferred(
|
||||
defer.Deferred(), consumeErrors=True
|
||||
)
|
||||
|
||||
end_item = _EventPersistQueueItem(
|
||||
events_and_contexts=[],
|
||||
backfilled=backfilled,
|
||||
task=task,
|
||||
deferred=deferred,
|
||||
)
|
||||
queue.append(end_item)
|
||||
|
||||
# add our events to the queue item
|
||||
end_item.events_and_contexts.extend(events_and_contexts)
|
||||
|
||||
# also add our active opentracing span to the item so that we get a link back
|
||||
span = opentracing.active_span()
|
||||
if span:
|
||||
|
@ -202,7 +233,7 @@ class _EventPeristenceQueue(Generic[_PersistResult]):
|
|||
|
||||
# add another opentracing span which links to the persist trace.
|
||||
with opentracing.start_active_span_follows_from(
|
||||
"persist_event_batch_complete", (end_item.opentracing_span_context,)
|
||||
f"{task.name}_complete", (end_item.opentracing_span_context,)
|
||||
):
|
||||
pass
|
||||
|
||||
|
@ -234,16 +265,14 @@ class _EventPeristenceQueue(Generic[_PersistResult]):
|
|||
for item in queue:
|
||||
try:
|
||||
with opentracing.start_active_span_follows_from(
|
||||
"persist_event_batch",
|
||||
item.task.name,
|
||||
item.parent_opentracing_span_contexts,
|
||||
inherit_force_tracing=True,
|
||||
) as scope:
|
||||
if scope:
|
||||
item.opentracing_span_context = scope.span.context
|
||||
|
||||
ret = await self._per_item_callback(
|
||||
item.events_and_contexts, item.backfilled
|
||||
)
|
||||
ret = await self._per_item_callback(room_id, item.task)
|
||||
except Exception:
|
||||
with PreserveLoggingContext():
|
||||
item.deferred.errback()
|
||||
|
@ -292,9 +321,32 @@ class EventsPersistenceStorageController:
|
|||
self._clock = hs.get_clock()
|
||||
self._instance_name = hs.get_instance_name()
|
||||
self.is_mine_id = hs.is_mine_id
|
||||
self._event_persist_queue = _EventPeristenceQueue(self._persist_event_batch)
|
||||
self._event_persist_queue = _EventPeristenceQueue(
|
||||
self._process_event_persist_queue_task
|
||||
)
|
||||
self._state_resolution_handler = hs.get_state_resolution_handler()
|
||||
|
||||
async def _process_event_persist_queue_task(
|
||||
self,
|
||||
room_id: str,
|
||||
task: _EventPersistQueueTask,
|
||||
) -> Dict[str, str]:
|
||||
"""Callback for the _event_persist_queue
|
||||
|
||||
Returns:
|
||||
A dictionary of event ID to event ID we didn't persist as we already
|
||||
had another event persisted with the same TXN ID.
|
||||
"""
|
||||
if isinstance(task, _PersistEventsTask):
|
||||
return await self._persist_event_batch(room_id, task)
|
||||
elif isinstance(task, _UpdateCurrentStateTask):
|
||||
await self._update_current_state(room_id, task)
|
||||
return {}
|
||||
else:
|
||||
raise AssertionError(
|
||||
f"Found an unexpected task type in event persistence queue: {task}"
|
||||
)
|
||||
|
||||
@opentracing.trace
|
||||
async def persist_events(
|
||||
self,
|
||||
|
@ -329,7 +381,8 @@ class EventsPersistenceStorageController:
|
|||
) -> Dict[str, str]:
|
||||
room_id, evs_ctxs = item
|
||||
return await self._event_persist_queue.add_to_queue(
|
||||
room_id, evs_ctxs, backfilled=backfilled
|
||||
room_id,
|
||||
_PersistEventsTask(events_and_contexts=evs_ctxs, backfilled=backfilled),
|
||||
)
|
||||
|
||||
ret_vals = await yieldable_gather_results(enqueue, partitioned.items())
|
||||
|
@ -376,7 +429,10 @@ class EventsPersistenceStorageController:
|
|||
# event was deduplicated. (The dict may also include other entries if
|
||||
# the event was persisted in a batch with other events.)
|
||||
replaced_events = await self._event_persist_queue.add_to_queue(
|
||||
event.room_id, [(event, context)], backfilled=backfilled
|
||||
event.room_id,
|
||||
_PersistEventsTask(
|
||||
events_and_contexts=[(event, context)], backfilled=backfilled
|
||||
),
|
||||
)
|
||||
replaced_event = replaced_events.get(event.event_id)
|
||||
if replaced_event:
|
||||
|
@ -391,20 +447,22 @@ class EventsPersistenceStorageController:
|
|||
|
||||
async def update_current_state(self, room_id: str) -> None:
|
||||
"""Recalculate the current state for a room, and persist it"""
|
||||
await self._event_persist_queue.add_to_queue(
|
||||
room_id,
|
||||
_UpdateCurrentStateTask(),
|
||||
)
|
||||
|
||||
async def _update_current_state(
|
||||
self, room_id: str, _task: _UpdateCurrentStateTask
|
||||
) -> None:
|
||||
"""Callback for the _event_persist_queue
|
||||
|
||||
Recalculates the current state for a room, and persists it.
|
||||
"""
|
||||
state = await self._calculate_current_state(room_id)
|
||||
delta = await self._calculate_state_delta(room_id, state)
|
||||
|
||||
# TODO(faster_joins): get a real stream ordering, to make this work correctly
|
||||
# across workers.
|
||||
# https://github.com/matrix-org/synapse/issues/12994
|
||||
#
|
||||
# TODO(faster_joins): this can race against event persistence, in which case we
|
||||
# will end up with incorrect state. Perhaps we should make this a job we
|
||||
# farm out to the event persister thread, somehow.
|
||||
# https://github.com/matrix-org/synapse/issues/13007
|
||||
#
|
||||
stream_id = self.main_store.get_room_max_stream_ordering()
|
||||
await self.persist_events_store.update_current_state(room_id, delta, stream_id)
|
||||
await self.persist_events_store.update_current_state(room_id, delta)
|
||||
|
||||
async def _calculate_current_state(self, room_id: str) -> StateMap[str]:
|
||||
"""Calculate the current state of a room, based on the forward extremities
|
||||
|
@ -449,9 +507,7 @@ class EventsPersistenceStorageController:
|
|||
return res.state
|
||||
|
||||
async def _persist_event_batch(
|
||||
self,
|
||||
events_and_contexts: List[Tuple[EventBase, EventContext]],
|
||||
backfilled: bool = False,
|
||||
self, _room_id: str, task: _PersistEventsTask
|
||||
) -> Dict[str, str]:
|
||||
"""Callback for the _event_persist_queue
|
||||
|
||||
|
@ -466,6 +522,9 @@ class EventsPersistenceStorageController:
|
|||
PartialStateConflictError: if attempting to persist a partial state event in
|
||||
a room that has been un-partial stated.
|
||||
"""
|
||||
events_and_contexts = task.events_and_contexts
|
||||
backfilled = task.backfilled
|
||||
|
||||
replaced_events: Dict[str, str] = {}
|
||||
if not events_and_contexts:
|
||||
return replaced_events
|
||||
|
|
|
@ -1007,16 +1007,16 @@ class PersistEventsStore:
|
|||
self,
|
||||
room_id: str,
|
||||
state_delta: DeltaState,
|
||||
stream_id: int,
|
||||
) -> None:
|
||||
"""Update the current state stored in the database for the given room"""
|
||||
|
||||
await self.db_pool.runInteraction(
|
||||
"update_current_state",
|
||||
self._update_current_state_txn,
|
||||
state_delta_by_room={room_id: state_delta},
|
||||
stream_id=stream_id,
|
||||
)
|
||||
async with self._stream_id_gen.get_next() as stream_ordering:
|
||||
await self.db_pool.runInteraction(
|
||||
"update_current_state",
|
||||
self._update_current_state_txn,
|
||||
state_delta_by_room={room_id: state_delta},
|
||||
stream_id=stream_ordering,
|
||||
)
|
||||
|
||||
def _update_current_state_txn(
|
||||
self,
|
||||
|
|
|
@ -195,6 +195,8 @@ class StateTestCase(unittest.TestCase):
|
|||
"get_state_resolution_handler",
|
||||
"get_account_validity_handler",
|
||||
"get_macaroon_generator",
|
||||
"get_instance_name",
|
||||
"get_simple_http_client",
|
||||
"hostname",
|
||||
]
|
||||
)
|
||||
|
|
Loading…
Reference in New Issue