Pull out less state when handling gaps mk2 (#12852)
parent
1b338476af
commit
b83bc5fab5
|
@ -0,0 +1 @@
|
|||
Pull out less state when handling gaps in room DAG.
|
|
@ -274,7 +274,7 @@ class FederationEventHandler:
|
|||
affected=pdu.event_id,
|
||||
)
|
||||
|
||||
await self._process_received_pdu(origin, pdu, state=None)
|
||||
await self._process_received_pdu(origin, pdu, state_ids=None)
|
||||
|
||||
async def on_send_membership_event(
|
||||
self, origin: str, event: EventBase
|
||||
|
@ -463,7 +463,9 @@ class FederationEventHandler:
|
|||
with nested_logging_context(suffix=event.event_id):
|
||||
context = await self._state_handler.compute_event_context(
|
||||
event,
|
||||
old_state=state,
|
||||
state_ids_before_event={
|
||||
(e.type, e.state_key): e.event_id for e in state
|
||||
},
|
||||
partial_state=partial_state,
|
||||
)
|
||||
|
||||
|
@ -512,12 +514,12 @@ class FederationEventHandler:
|
|||
#
|
||||
# This is the same operation as we do when we receive a regular event
|
||||
# over federation.
|
||||
state = await self._resolve_state_at_missing_prevs(destination, event)
|
||||
state_ids = await self._resolve_state_at_missing_prevs(destination, event)
|
||||
|
||||
# build a new state group for it if need be
|
||||
context = await self._state_handler.compute_event_context(
|
||||
event,
|
||||
old_state=state,
|
||||
state_ids_before_event=state_ids,
|
||||
)
|
||||
if context.partial_state:
|
||||
# this can happen if some or all of the event's prev_events still have
|
||||
|
@ -767,11 +769,12 @@ class FederationEventHandler:
|
|||
return
|
||||
|
||||
try:
|
||||
state = await self._resolve_state_at_missing_prevs(origin, event)
|
||||
state_ids = await self._resolve_state_at_missing_prevs(origin, event)
|
||||
# TODO(faster_joins): make sure that _resolve_state_at_missing_prevs does
|
||||
# not return partial state
|
||||
|
||||
await self._process_received_pdu(
|
||||
origin, event, state=state, backfilled=backfilled
|
||||
origin, event, state_ids=state_ids, backfilled=backfilled
|
||||
)
|
||||
except FederationError as e:
|
||||
if e.code == 403:
|
||||
|
@ -781,7 +784,7 @@ class FederationEventHandler:
|
|||
|
||||
async def _resolve_state_at_missing_prevs(
|
||||
self, dest: str, event: EventBase
|
||||
) -> Optional[Iterable[EventBase]]:
|
||||
) -> Optional[StateMap[str]]:
|
||||
"""Calculate the state at an event with missing prev_events.
|
||||
|
||||
This is used when we have pulled a batch of events from a remote server, and
|
||||
|
@ -808,8 +811,8 @@ class FederationEventHandler:
|
|||
event: an event to check for missing prevs.
|
||||
|
||||
Returns:
|
||||
if we already had all the prev events, `None`. Otherwise, returns a list of
|
||||
the events in the state at `event`.
|
||||
if we already had all the prev events, `None`. Otherwise, returns
|
||||
the event ids of the state at `event`.
|
||||
"""
|
||||
room_id = event.room_id
|
||||
event_id = event.event_id
|
||||
|
@ -829,7 +832,7 @@ class FederationEventHandler:
|
|||
)
|
||||
# Calculate the state after each of the previous events, and
|
||||
# resolve them to find the correct state at the current event.
|
||||
event_map = {event_id: event}
|
||||
|
||||
try:
|
||||
# Get the state of the events we know about
|
||||
ours = await self._state_storage.get_state_groups_ids(room_id, seen)
|
||||
|
@ -849,40 +852,23 @@ class FederationEventHandler:
|
|||
# note that if any of the missing prevs share missing state or
|
||||
# auth events, the requests to fetch those events are deduped
|
||||
# by the get_pdu_cache in federation_client.
|
||||
remote_state = await self._get_state_after_missing_prev_event(
|
||||
dest, room_id, p
|
||||
remote_state_map = (
|
||||
await self._get_state_ids_after_missing_prev_event(
|
||||
dest, room_id, p
|
||||
)
|
||||
)
|
||||
|
||||
remote_state_map = {
|
||||
(x.type, x.state_key): x.event_id for x in remote_state
|
||||
}
|
||||
state_maps.append(remote_state_map)
|
||||
|
||||
for x in remote_state:
|
||||
event_map[x.event_id] = x
|
||||
|
||||
room_version = await self._store.get_room_version_id(room_id)
|
||||
state_map = await self._state_resolution_handler.resolve_events_with_store(
|
||||
room_id,
|
||||
room_version,
|
||||
state_maps,
|
||||
event_map,
|
||||
event_map={event_id: event},
|
||||
state_res_store=StateResolutionStore(self._store),
|
||||
)
|
||||
|
||||
# We need to give _process_received_pdu the actual state events
|
||||
# rather than event ids, so generate that now.
|
||||
|
||||
# First though we need to fetch all the events that are in
|
||||
# state_map, so we can build up the state below.
|
||||
evs = await self._store.get_events(
|
||||
list(state_map.values()),
|
||||
get_prev_content=False,
|
||||
redact_behaviour=EventRedactBehaviour.as_is,
|
||||
)
|
||||
event_map.update(evs)
|
||||
|
||||
state = [event_map[e] for e in state_map.values()]
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Error attempting to resolve state at missing prev_events",
|
||||
|
@ -894,14 +880,14 @@ class FederationEventHandler:
|
|||
"We can't get valid state history.",
|
||||
affected=event_id,
|
||||
)
|
||||
return state
|
||||
return state_map
|
||||
|
||||
async def _get_state_after_missing_prev_event(
|
||||
async def _get_state_ids_after_missing_prev_event(
|
||||
self,
|
||||
destination: str,
|
||||
room_id: str,
|
||||
event_id: str,
|
||||
) -> List[EventBase]:
|
||||
) -> StateMap[str]:
|
||||
"""Requests all of the room state at a given event from a remote homeserver.
|
||||
|
||||
Args:
|
||||
|
@ -910,7 +896,7 @@ class FederationEventHandler:
|
|||
event_id: The id of the event we want the state at.
|
||||
|
||||
Returns:
|
||||
A list of events in the state, including the event itself
|
||||
The event ids of the state *after* the given event.
|
||||
"""
|
||||
(
|
||||
state_event_ids,
|
||||
|
@ -925,19 +911,17 @@ class FederationEventHandler:
|
|||
len(auth_event_ids),
|
||||
)
|
||||
|
||||
# start by just trying to fetch the events from the store
|
||||
# Start by checking events we already have in the DB
|
||||
desired_events = set(state_event_ids)
|
||||
desired_events.add(event_id)
|
||||
logger.debug("Fetching %i events from cache/store", len(desired_events))
|
||||
fetched_events = await self._store.get_events(
|
||||
desired_events, allow_rejected=True
|
||||
)
|
||||
have_events = await self._store.have_seen_events(room_id, desired_events)
|
||||
|
||||
missing_desired_events = desired_events - fetched_events.keys()
|
||||
missing_desired_events = desired_events - have_events
|
||||
logger.debug(
|
||||
"We are missing %i events (got %i)",
|
||||
len(missing_desired_events),
|
||||
len(fetched_events),
|
||||
len(have_events),
|
||||
)
|
||||
|
||||
# We probably won't need most of the auth events, so let's just check which
|
||||
|
@ -948,7 +932,7 @@ class FederationEventHandler:
|
|||
# already have a bunch of the state events. It would be nice if the
|
||||
# federation api gave us a way of finding out which we actually need.
|
||||
|
||||
missing_auth_events = set(auth_event_ids) - fetched_events.keys()
|
||||
missing_auth_events = set(auth_event_ids) - have_events
|
||||
missing_auth_events.difference_update(
|
||||
await self._store.have_seen_events(room_id, missing_auth_events)
|
||||
)
|
||||
|
@ -974,47 +958,51 @@ class FederationEventHandler:
|
|||
destination=destination, room_id=room_id, event_ids=missing_events
|
||||
)
|
||||
|
||||
# we need to make sure we re-load from the database to get the rejected
|
||||
# state correct.
|
||||
fetched_events.update(
|
||||
await self._store.get_events(missing_desired_events, allow_rejected=True)
|
||||
)
|
||||
# We now need to fill out the state map, which involves fetching the
|
||||
# type and state key for each event ID in the state.
|
||||
state_map = {}
|
||||
|
||||
# check for events which were in the wrong room.
|
||||
#
|
||||
# this can happen if a remote server claims that the state or
|
||||
# auth_events at an event in room A are actually events in room B
|
||||
event_metadata = await self._store.get_metadata_for_events(state_event_ids)
|
||||
for state_event_id, metadata in event_metadata.items():
|
||||
if metadata.room_id != room_id:
|
||||
# This is a bogus situation, but since we may only discover it a long time
|
||||
# after it happened, we try our best to carry on, by just omitting the
|
||||
# bad events from the returned state set.
|
||||
#
|
||||
# This can happen if a remote server claims that the state or
|
||||
# auth_events at an event in room A are actually events in room B
|
||||
logger.warning(
|
||||
"Remote server %s claims event %s in room %s is an auth/state "
|
||||
"event in room %s",
|
||||
destination,
|
||||
state_event_id,
|
||||
metadata.room_id,
|
||||
room_id,
|
||||
)
|
||||
continue
|
||||
|
||||
bad_events = [
|
||||
(event_id, event.room_id)
|
||||
for event_id, event in fetched_events.items()
|
||||
if event.room_id != room_id
|
||||
]
|
||||
if metadata.state_key is None:
|
||||
logger.warning(
|
||||
"Remote server gave us non-state event in state: %s", state_event_id
|
||||
)
|
||||
continue
|
||||
|
||||
for bad_event_id, bad_room_id in bad_events:
|
||||
# This is a bogus situation, but since we may only discover it a long time
|
||||
# after it happened, we try our best to carry on, by just omitting the
|
||||
# bad events from the returned state set.
|
||||
logger.warning(
|
||||
"Remote server %s claims event %s in room %s is an auth/state "
|
||||
"event in room %s",
|
||||
destination,
|
||||
bad_event_id,
|
||||
bad_room_id,
|
||||
room_id,
|
||||
)
|
||||
|
||||
del fetched_events[bad_event_id]
|
||||
state_map[(metadata.event_type, metadata.state_key)] = state_event_id
|
||||
|
||||
# if we couldn't get the prev event in question, that's a problem.
|
||||
remote_event = fetched_events.get(event_id)
|
||||
remote_event = await self._store.get_event(
|
||||
event_id,
|
||||
allow_none=True,
|
||||
allow_rejected=True,
|
||||
redact_behaviour=EventRedactBehaviour.as_is,
|
||||
)
|
||||
if not remote_event:
|
||||
raise Exception("Unable to get missing prev_event %s" % (event_id,))
|
||||
|
||||
# missing state at that event is a warning, not a blocker
|
||||
# XXX: this doesn't sound right? it means that we'll end up with incomplete
|
||||
# state.
|
||||
failed_to_fetch = desired_events - fetched_events.keys()
|
||||
failed_to_fetch = desired_events - event_metadata.keys()
|
||||
if failed_to_fetch:
|
||||
logger.warning(
|
||||
"Failed to fetch missing state events for %s %s",
|
||||
|
@ -1022,14 +1010,12 @@ class FederationEventHandler:
|
|||
failed_to_fetch,
|
||||
)
|
||||
|
||||
remote_state = [
|
||||
fetched_events[e_id] for e_id in state_event_ids if e_id in fetched_events
|
||||
]
|
||||
|
||||
if remote_event.is_state() and remote_event.rejected_reason is None:
|
||||
remote_state.append(remote_event)
|
||||
state_map[
|
||||
(remote_event.type, remote_event.state_key)
|
||||
] = remote_event.event_id
|
||||
|
||||
return remote_state
|
||||
return state_map
|
||||
|
||||
async def _get_state_and_persist(
|
||||
self, destination: str, room_id: str, event_id: str
|
||||
|
@ -1056,7 +1042,7 @@ class FederationEventHandler:
|
|||
self,
|
||||
origin: str,
|
||||
event: EventBase,
|
||||
state: Optional[Iterable[EventBase]],
|
||||
state_ids: Optional[StateMap[str]],
|
||||
backfilled: bool = False,
|
||||
) -> None:
|
||||
"""Called when we have a new non-outlier event.
|
||||
|
@ -1078,7 +1064,7 @@ class FederationEventHandler:
|
|||
|
||||
event: event to be persisted
|
||||
|
||||
state: Normally None, but if we are handling a gap in the graph
|
||||
state_ids: Normally None, but if we are handling a gap in the graph
|
||||
(ie, we are missing one or more prev_events), the resolved state at the
|
||||
event
|
||||
|
||||
|
@ -1090,7 +1076,8 @@ class FederationEventHandler:
|
|||
|
||||
try:
|
||||
context = await self._state_handler.compute_event_context(
|
||||
event, old_state=state
|
||||
event,
|
||||
state_ids_before_event=state_ids,
|
||||
)
|
||||
context = await self._check_event_auth(
|
||||
origin,
|
||||
|
@ -1107,7 +1094,7 @@ class FederationEventHandler:
|
|||
# For new (non-backfilled and non-outlier) events we check if the event
|
||||
# passes auth based on the current state. If it doesn't then we
|
||||
# "soft-fail" the event.
|
||||
await self._check_for_soft_fail(event, state, origin=origin)
|
||||
await self._check_for_soft_fail(event, state_ids, origin=origin)
|
||||
|
||||
await self._run_push_actions_and_persist_event(event, context, backfilled)
|
||||
|
||||
|
@ -1589,7 +1576,7 @@ class FederationEventHandler:
|
|||
async def _check_for_soft_fail(
|
||||
self,
|
||||
event: EventBase,
|
||||
state: Optional[Iterable[EventBase]],
|
||||
state_ids: Optional[StateMap[str]],
|
||||
origin: str,
|
||||
) -> None:
|
||||
"""Checks if we should soft fail the event; if so, marks the event as
|
||||
|
@ -1597,7 +1584,7 @@ class FederationEventHandler:
|
|||
|
||||
Args:
|
||||
event
|
||||
state: The state at the event if we don't have all the event's prev events
|
||||
state_ids: The state at the event if we don't have all the event's prev events
|
||||
origin: The host the event originates from.
|
||||
"""
|
||||
extrem_ids_list = await self._store.get_latest_event_ids_in_room(event.room_id)
|
||||
|
@ -1613,7 +1600,7 @@ class FederationEventHandler:
|
|||
room_version_obj = KNOWN_ROOM_VERSIONS[room_version]
|
||||
|
||||
# Calculate the "current state".
|
||||
if state is not None:
|
||||
if state_ids is not None:
|
||||
# If we're explicitly given the state then we won't have all the
|
||||
# prev events, and so we have a gap in the graph. In this case
|
||||
# we want to be a little careful as we might have been down for
|
||||
|
@ -1626,17 +1613,20 @@ class FederationEventHandler:
|
|||
# given state at the event. This should correctly handle cases
|
||||
# like bans, especially with state res v2.
|
||||
|
||||
state_sets_d = await self._state_storage.get_state_groups(
|
||||
state_sets_d = await self._state_storage.get_state_groups_ids(
|
||||
event.room_id, extrem_ids
|
||||
)
|
||||
state_sets: List[Iterable[EventBase]] = list(state_sets_d.values())
|
||||
state_sets.append(state)
|
||||
current_states = await self._state_handler.resolve_events(
|
||||
room_version, state_sets, event
|
||||
state_sets: List[StateMap[str]] = list(state_sets_d.values())
|
||||
state_sets.append(state_ids)
|
||||
current_state_ids = (
|
||||
await self._state_resolution_handler.resolve_events_with_store(
|
||||
event.room_id,
|
||||
room_version,
|
||||
state_sets,
|
||||
event_map=None,
|
||||
state_res_store=StateResolutionStore(self._store),
|
||||
)
|
||||
)
|
||||
current_state_ids: StateMap[str] = {
|
||||
k: e.event_id for k, e in current_states.items()
|
||||
}
|
||||
else:
|
||||
current_state_ids = await self._state_handler.get_current_state_ids(
|
||||
event.room_id, latest_event_ids=extrem_ids
|
||||
|
|
|
@ -55,7 +55,14 @@ from synapse.metrics.background_process_metrics import run_as_background_process
|
|||
from synapse.replication.http.send_event import ReplicationSendEventRestServlet
|
||||
from synapse.storage.databases.main.events_worker import EventRedactBehaviour
|
||||
from synapse.storage.state import StateFilter
|
||||
from synapse.types import Requester, RoomAlias, StreamToken, UserID, create_requester
|
||||
from synapse.types import (
|
||||
MutableStateMap,
|
||||
Requester,
|
||||
RoomAlias,
|
||||
StreamToken,
|
||||
UserID,
|
||||
create_requester,
|
||||
)
|
||||
from synapse.util import json_decoder, json_encoder, log_failure, unwrapFirstError
|
||||
from synapse.util.async_helpers import Linearizer, gather_results
|
||||
from synapse.util.caches.expiringcache import ExpiringCache
|
||||
|
@ -1022,8 +1029,35 @@ class EventCreationHandler:
|
|||
#
|
||||
# TODO(faster_joins): figure out how this works, and make sure that the
|
||||
# old state is complete.
|
||||
old_state = await self.store.get_events_as_list(state_event_ids)
|
||||
context = await self.state.compute_event_context(event, old_state=old_state)
|
||||
metadata = await self.store.get_metadata_for_events(state_event_ids)
|
||||
|
||||
state_map_for_event: MutableStateMap[str] = {}
|
||||
for state_id in state_event_ids:
|
||||
data = metadata.get(state_id)
|
||||
if data is None:
|
||||
# We're trying to persist a new historical batch of events
|
||||
# with the given state, e.g. via
|
||||
# `RoomBatchSendEventRestServlet`. The state can be inferred
|
||||
# by Synapse or set directly by the client.
|
||||
#
|
||||
# Either way, we should have persisted all the state before
|
||||
# getting here.
|
||||
raise Exception(
|
||||
f"State event {state_id} not found in DB,"
|
||||
" Synapse should have persisted it before using it."
|
||||
)
|
||||
|
||||
if data.state_key is None:
|
||||
raise Exception(
|
||||
f"Trying to set non-state event {state_id} as state"
|
||||
)
|
||||
|
||||
state_map_for_event[(data.event_type, data.state_key)] = state_id
|
||||
|
||||
context = await self.state.compute_event_context(
|
||||
event,
|
||||
state_ids_before_event=state_map_for_event,
|
||||
)
|
||||
else:
|
||||
context = await self.state.compute_event_context(event)
|
||||
|
||||
|
|
|
@ -261,7 +261,7 @@ class StateHandler:
|
|||
async def compute_event_context(
|
||||
self,
|
||||
event: EventBase,
|
||||
old_state: Optional[Iterable[EventBase]] = None,
|
||||
state_ids_before_event: Optional[StateMap[str]] = None,
|
||||
partial_state: bool = False,
|
||||
) -> EventContext:
|
||||
"""Build an EventContext structure for a non-outlier event.
|
||||
|
@ -273,12 +273,12 @@ class StateHandler:
|
|||
|
||||
Args:
|
||||
event:
|
||||
old_state: The state at the event if it can't be
|
||||
calculated from existing events. This is normally only specified
|
||||
when receiving an event from federation where we don't have the
|
||||
prev events for, e.g. when backfilling.
|
||||
partial_state: True if `old_state` is partial and omits non-critical
|
||||
membership events
|
||||
state_ids_before_event: The event ids of the state before the event if
|
||||
it can't be calculated from existing events. This is normally
|
||||
only specified when receiving an event from federation where we
|
||||
don't have the prev events, e.g. when backfilling.
|
||||
partial_state: True if `state_ids_before_event` is partial and omits
|
||||
non-critical membership events
|
||||
Returns:
|
||||
The event context.
|
||||
"""
|
||||
|
@ -286,13 +286,11 @@ class StateHandler:
|
|||
assert not event.internal_metadata.is_outlier()
|
||||
|
||||
#
|
||||
# first of all, figure out the state before the event
|
||||
# first of all, figure out the state before the event, unless we
|
||||
# already have it.
|
||||
#
|
||||
if old_state:
|
||||
if state_ids_before_event:
|
||||
# if we're given the state before the event, then we use that
|
||||
state_ids_before_event: StateMap[str] = {
|
||||
(s.type, s.state_key): s.event_id for s in old_state
|
||||
}
|
||||
state_group_before_event = None
|
||||
state_group_before_event_prev_group = None
|
||||
deltas_to_state_group_before_event = None
|
||||
|
|
|
@ -16,6 +16,8 @@ import collections.abc
|
|||
import logging
|
||||
from typing import TYPE_CHECKING, Collection, Dict, Iterable, Optional, Set, Tuple
|
||||
|
||||
import attr
|
||||
|
||||
from synapse.api.constants import EventTypes, Membership
|
||||
from synapse.api.errors import NotFoundError, UnsupportedRoomVersionError
|
||||
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion
|
||||
|
@ -26,6 +28,7 @@ from synapse.storage.database import (
|
|||
DatabasePool,
|
||||
LoggingDatabaseConnection,
|
||||
LoggingTransaction,
|
||||
make_in_list_sql_clause,
|
||||
)
|
||||
from synapse.storage.databases.main.events_worker import EventsWorkerStore
|
||||
from synapse.storage.databases.main.roommember import RoomMemberWorkerStore
|
||||
|
@ -33,6 +36,7 @@ from synapse.storage.state import StateFilter
|
|||
from synapse.types import JsonDict, JsonMapping, StateMap
|
||||
from synapse.util.caches import intern_string
|
||||
from synapse.util.caches.descriptors import cached, cachedList
|
||||
from synapse.util.iterutils import batch_iter
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from synapse.server import HomeServer
|
||||
|
@ -43,6 +47,15 @@ logger = logging.getLogger(__name__)
|
|||
MAX_STATE_DELTA_HOPS = 100
|
||||
|
||||
|
||||
@attr.s(slots=True, frozen=True, auto_attribs=True)
|
||||
class EventMetadata:
|
||||
"""Returned by `get_metadata_for_events`"""
|
||||
|
||||
room_id: str
|
||||
event_type: str
|
||||
state_key: Optional[str]
|
||||
|
||||
|
||||
def _retrieve_and_check_room_version(room_id: str, room_version_id: str) -> RoomVersion:
|
||||
v = KNOWN_ROOM_VERSIONS.get(room_version_id)
|
||||
if not v:
|
||||
|
@ -133,6 +146,52 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
|
|||
|
||||
return room_version
|
||||
|
||||
async def get_metadata_for_events(
|
||||
self, event_ids: Collection[str]
|
||||
) -> Dict[str, EventMetadata]:
|
||||
"""Get some metadata (room_id, type, state_key) for the given events.
|
||||
|
||||
This method is a faster alternative than fetching the full events from
|
||||
the DB, and should be used when the full event is not needed.
|
||||
|
||||
Returns metadata for rejected and redacted events. Events that have not
|
||||
been persisted are omitted from the returned dict.
|
||||
"""
|
||||
|
||||
def get_metadata_for_events_txn(
|
||||
txn: LoggingTransaction,
|
||||
batch_ids: Collection[str],
|
||||
) -> Dict[str, EventMetadata]:
|
||||
clause, args = make_in_list_sql_clause(
|
||||
self.database_engine, "e.event_id", batch_ids
|
||||
)
|
||||
|
||||
sql = f"""
|
||||
SELECT e.event_id, e.room_id, e.type, e.state_key FROM events AS e
|
||||
LEFT JOIN state_events USING (event_id)
|
||||
WHERE {clause}
|
||||
"""
|
||||
|
||||
txn.execute(sql, args)
|
||||
return {
|
||||
event_id: EventMetadata(
|
||||
room_id=room_id, event_type=event_type, state_key=state_key
|
||||
)
|
||||
for event_id, room_id, event_type, state_key in txn
|
||||
}
|
||||
|
||||
result_map: Dict[str, EventMetadata] = {}
|
||||
for batch_ids in batch_iter(event_ids, 1000):
|
||||
result_map.update(
|
||||
await self.db_pool.runInteraction(
|
||||
"get_metadata_for_events",
|
||||
get_metadata_for_events_txn,
|
||||
batch_ids=batch_ids,
|
||||
)
|
||||
)
|
||||
|
||||
return result_map
|
||||
|
||||
async def get_room_predecessor(self, room_id: str) -> Optional[JsonMapping]:
|
||||
"""Get the predecessor of an upgraded room if it exists.
|
||||
Otherwise return None.
|
||||
|
|
|
@ -276,7 +276,11 @@ class FederationTestCase(unittest.FederatingHomeserverTestCase):
|
|||
# federation handler wanting to backfill the fake event.
|
||||
self.get_success(
|
||||
federation_event_handler._process_received_pdu(
|
||||
self.OTHER_SERVER_NAME, event, state=current_state
|
||||
self.OTHER_SERVER_NAME,
|
||||
event,
|
||||
state_ids={
|
||||
(e.type, e.state_key): e.event_id for e in current_state
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
|
|
|
@ -69,7 +69,7 @@ class ExtremPruneTestCase(HomeserverTestCase):
|
|||
def persist_event(self, event, state=None):
|
||||
"""Persist the event, with optional state"""
|
||||
context = self.get_success(
|
||||
self.state.compute_event_context(event, old_state=state)
|
||||
self.state.compute_event_context(event, state_ids_before_event=state)
|
||||
)
|
||||
self.get_success(self.persistence.persist_event(event, context))
|
||||
|
||||
|
@ -103,9 +103,11 @@ class ExtremPruneTestCase(HomeserverTestCase):
|
|||
RoomVersions.V6,
|
||||
)
|
||||
|
||||
state_before_gap = self.get_success(self.state.get_current_state(self.room_id))
|
||||
state_before_gap = self.get_success(
|
||||
self.state.get_current_state_ids(self.room_id)
|
||||
)
|
||||
|
||||
self.persist_event(remote_event_2, state=state_before_gap.values())
|
||||
self.persist_event(remote_event_2, state=state_before_gap)
|
||||
|
||||
# Check the new extremity is just the new remote event.
|
||||
self.assert_extremities([remote_event_2.event_id])
|
||||
|
@ -135,13 +137,14 @@ class ExtremPruneTestCase(HomeserverTestCase):
|
|||
# setting. The state resolution across the old and new event will then
|
||||
# include it, and so the resolved state won't match the new state.
|
||||
state_before_gap = dict(
|
||||
self.get_success(self.state.get_current_state(self.room_id))
|
||||
self.get_success(self.state.get_current_state_ids(self.room_id))
|
||||
)
|
||||
state_before_gap.pop(("m.room.history_visibility", ""))
|
||||
|
||||
context = self.get_success(
|
||||
self.state.compute_event_context(
|
||||
remote_event_2, old_state=state_before_gap.values()
|
||||
remote_event_2,
|
||||
state_ids_before_event=state_before_gap,
|
||||
)
|
||||
)
|
||||
|
||||
|
@ -177,9 +180,11 @@ class ExtremPruneTestCase(HomeserverTestCase):
|
|||
RoomVersions.V6,
|
||||
)
|
||||
|
||||
state_before_gap = self.get_success(self.state.get_current_state(self.room_id))
|
||||
state_before_gap = self.get_success(
|
||||
self.state.get_current_state_ids(self.room_id)
|
||||
)
|
||||
|
||||
self.persist_event(remote_event_2, state=state_before_gap.values())
|
||||
self.persist_event(remote_event_2, state=state_before_gap)
|
||||
|
||||
# Check the new extremity is just the new remote event.
|
||||
self.assert_extremities([remote_event_2.event_id])
|
||||
|
@ -207,9 +212,11 @@ class ExtremPruneTestCase(HomeserverTestCase):
|
|||
RoomVersions.V6,
|
||||
)
|
||||
|
||||
state_before_gap = self.get_success(self.state.get_current_state(self.room_id))
|
||||
state_before_gap = self.get_success(
|
||||
self.state.get_current_state_ids(self.room_id)
|
||||
)
|
||||
|
||||
self.persist_event(remote_event_2, state=state_before_gap.values())
|
||||
self.persist_event(remote_event_2, state=state_before_gap)
|
||||
|
||||
# Check the new extremity is just the new remote event.
|
||||
self.assert_extremities([self.remote_event_1.event_id, remote_event_2.event_id])
|
||||
|
@ -247,9 +254,11 @@ class ExtremPruneTestCase(HomeserverTestCase):
|
|||
RoomVersions.V6,
|
||||
)
|
||||
|
||||
state_before_gap = self.get_success(self.state.get_current_state(self.room_id))
|
||||
state_before_gap = self.get_success(
|
||||
self.state.get_current_state_ids(self.room_id)
|
||||
)
|
||||
|
||||
self.persist_event(remote_event_2, state=state_before_gap.values())
|
||||
self.persist_event(remote_event_2, state=state_before_gap)
|
||||
|
||||
# Check the new extremity is just the new remote event.
|
||||
self.assert_extremities([remote_event_2.event_id])
|
||||
|
@ -289,9 +298,11 @@ class ExtremPruneTestCase(HomeserverTestCase):
|
|||
RoomVersions.V6,
|
||||
)
|
||||
|
||||
state_before_gap = self.get_success(self.state.get_current_state(self.room_id))
|
||||
state_before_gap = self.get_success(
|
||||
self.state.get_current_state_ids(self.room_id)
|
||||
)
|
||||
|
||||
self.persist_event(remote_event_2, state=state_before_gap.values())
|
||||
self.persist_event(remote_event_2, state=state_before_gap)
|
||||
|
||||
# Check the new extremity is just the new remote event.
|
||||
self.assert_extremities([remote_event_2.event_id, local_message_event_id])
|
||||
|
@ -323,9 +334,11 @@ class ExtremPruneTestCase(HomeserverTestCase):
|
|||
RoomVersions.V6,
|
||||
)
|
||||
|
||||
state_before_gap = self.get_success(self.state.get_current_state(self.room_id))
|
||||
state_before_gap = self.get_success(
|
||||
self.state.get_current_state_ids(self.room_id)
|
||||
)
|
||||
|
||||
self.persist_event(remote_event_2, state=state_before_gap.values())
|
||||
self.persist_event(remote_event_2, state=state_before_gap)
|
||||
|
||||
# Check the new extremity is just the new remote event.
|
||||
self.assert_extremities([local_message_event_id, remote_event_2.event_id])
|
||||
|
|
|
@ -442,7 +442,12 @@ class StateTestCase(unittest.TestCase):
|
|||
]
|
||||
|
||||
context = yield defer.ensureDeferred(
|
||||
self.state.compute_event_context(event, old_state=old_state)
|
||||
self.state.compute_event_context(
|
||||
event,
|
||||
state_ids_before_event={
|
||||
(e.type, e.state_key): e.event_id for e in old_state
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
prev_state_ids = yield defer.ensureDeferred(context.get_prev_state_ids())
|
||||
|
@ -467,7 +472,12 @@ class StateTestCase(unittest.TestCase):
|
|||
]
|
||||
|
||||
context = yield defer.ensureDeferred(
|
||||
self.state.compute_event_context(event, old_state=old_state)
|
||||
self.state.compute_event_context(
|
||||
event,
|
||||
state_ids_before_event={
|
||||
(e.type, e.state_key): e.event_id for e in old_state
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
prev_state_ids = yield defer.ensureDeferred(context.get_prev_state_ids())
|
||||
|
|
Loading…
Reference in New Issue