Pull out less state when handling gaps mk2 (#12852)
parent
1b338476af
commit
b83bc5fab5
|
@ -0,0 +1 @@
|
||||||
|
Pull out less state when handling gaps in room DAG.
|
|
@ -274,7 +274,7 @@ class FederationEventHandler:
|
||||||
affected=pdu.event_id,
|
affected=pdu.event_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
await self._process_received_pdu(origin, pdu, state=None)
|
await self._process_received_pdu(origin, pdu, state_ids=None)
|
||||||
|
|
||||||
async def on_send_membership_event(
|
async def on_send_membership_event(
|
||||||
self, origin: str, event: EventBase
|
self, origin: str, event: EventBase
|
||||||
|
@ -463,7 +463,9 @@ class FederationEventHandler:
|
||||||
with nested_logging_context(suffix=event.event_id):
|
with nested_logging_context(suffix=event.event_id):
|
||||||
context = await self._state_handler.compute_event_context(
|
context = await self._state_handler.compute_event_context(
|
||||||
event,
|
event,
|
||||||
old_state=state,
|
state_ids_before_event={
|
||||||
|
(e.type, e.state_key): e.event_id for e in state
|
||||||
|
},
|
||||||
partial_state=partial_state,
|
partial_state=partial_state,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -512,12 +514,12 @@ class FederationEventHandler:
|
||||||
#
|
#
|
||||||
# This is the same operation as we do when we receive a regular event
|
# This is the same operation as we do when we receive a regular event
|
||||||
# over federation.
|
# over federation.
|
||||||
state = await self._resolve_state_at_missing_prevs(destination, event)
|
state_ids = await self._resolve_state_at_missing_prevs(destination, event)
|
||||||
|
|
||||||
# build a new state group for it if need be
|
# build a new state group for it if need be
|
||||||
context = await self._state_handler.compute_event_context(
|
context = await self._state_handler.compute_event_context(
|
||||||
event,
|
event,
|
||||||
old_state=state,
|
state_ids_before_event=state_ids,
|
||||||
)
|
)
|
||||||
if context.partial_state:
|
if context.partial_state:
|
||||||
# this can happen if some or all of the event's prev_events still have
|
# this can happen if some or all of the event's prev_events still have
|
||||||
|
@ -767,11 +769,12 @@ class FederationEventHandler:
|
||||||
return
|
return
|
||||||
|
|
||||||
try:
|
try:
|
||||||
state = await self._resolve_state_at_missing_prevs(origin, event)
|
state_ids = await self._resolve_state_at_missing_prevs(origin, event)
|
||||||
# TODO(faster_joins): make sure that _resolve_state_at_missing_prevs does
|
# TODO(faster_joins): make sure that _resolve_state_at_missing_prevs does
|
||||||
# not return partial state
|
# not return partial state
|
||||||
|
|
||||||
await self._process_received_pdu(
|
await self._process_received_pdu(
|
||||||
origin, event, state=state, backfilled=backfilled
|
origin, event, state_ids=state_ids, backfilled=backfilled
|
||||||
)
|
)
|
||||||
except FederationError as e:
|
except FederationError as e:
|
||||||
if e.code == 403:
|
if e.code == 403:
|
||||||
|
@ -781,7 +784,7 @@ class FederationEventHandler:
|
||||||
|
|
||||||
async def _resolve_state_at_missing_prevs(
|
async def _resolve_state_at_missing_prevs(
|
||||||
self, dest: str, event: EventBase
|
self, dest: str, event: EventBase
|
||||||
) -> Optional[Iterable[EventBase]]:
|
) -> Optional[StateMap[str]]:
|
||||||
"""Calculate the state at an event with missing prev_events.
|
"""Calculate the state at an event with missing prev_events.
|
||||||
|
|
||||||
This is used when we have pulled a batch of events from a remote server, and
|
This is used when we have pulled a batch of events from a remote server, and
|
||||||
|
@ -808,8 +811,8 @@ class FederationEventHandler:
|
||||||
event: an event to check for missing prevs.
|
event: an event to check for missing prevs.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
if we already had all the prev events, `None`. Otherwise, returns a list of
|
if we already had all the prev events, `None`. Otherwise, returns
|
||||||
the events in the state at `event`.
|
the event ids of the state at `event`.
|
||||||
"""
|
"""
|
||||||
room_id = event.room_id
|
room_id = event.room_id
|
||||||
event_id = event.event_id
|
event_id = event.event_id
|
||||||
|
@ -829,7 +832,7 @@ class FederationEventHandler:
|
||||||
)
|
)
|
||||||
# Calculate the state after each of the previous events, and
|
# Calculate the state after each of the previous events, and
|
||||||
# resolve them to find the correct state at the current event.
|
# resolve them to find the correct state at the current event.
|
||||||
event_map = {event_id: event}
|
|
||||||
try:
|
try:
|
||||||
# Get the state of the events we know about
|
# Get the state of the events we know about
|
||||||
ours = await self._state_storage.get_state_groups_ids(room_id, seen)
|
ours = await self._state_storage.get_state_groups_ids(room_id, seen)
|
||||||
|
@ -849,40 +852,23 @@ class FederationEventHandler:
|
||||||
# note that if any of the missing prevs share missing state or
|
# note that if any of the missing prevs share missing state or
|
||||||
# auth events, the requests to fetch those events are deduped
|
# auth events, the requests to fetch those events are deduped
|
||||||
# by the get_pdu_cache in federation_client.
|
# by the get_pdu_cache in federation_client.
|
||||||
remote_state = await self._get_state_after_missing_prev_event(
|
remote_state_map = (
|
||||||
dest, room_id, p
|
await self._get_state_ids_after_missing_prev_event(
|
||||||
|
dest, room_id, p
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
remote_state_map = {
|
|
||||||
(x.type, x.state_key): x.event_id for x in remote_state
|
|
||||||
}
|
|
||||||
state_maps.append(remote_state_map)
|
state_maps.append(remote_state_map)
|
||||||
|
|
||||||
for x in remote_state:
|
|
||||||
event_map[x.event_id] = x
|
|
||||||
|
|
||||||
room_version = await self._store.get_room_version_id(room_id)
|
room_version = await self._store.get_room_version_id(room_id)
|
||||||
state_map = await self._state_resolution_handler.resolve_events_with_store(
|
state_map = await self._state_resolution_handler.resolve_events_with_store(
|
||||||
room_id,
|
room_id,
|
||||||
room_version,
|
room_version,
|
||||||
state_maps,
|
state_maps,
|
||||||
event_map,
|
event_map={event_id: event},
|
||||||
state_res_store=StateResolutionStore(self._store),
|
state_res_store=StateResolutionStore(self._store),
|
||||||
)
|
)
|
||||||
|
|
||||||
# We need to give _process_received_pdu the actual state events
|
|
||||||
# rather than event ids, so generate that now.
|
|
||||||
|
|
||||||
# First though we need to fetch all the events that are in
|
|
||||||
# state_map, so we can build up the state below.
|
|
||||||
evs = await self._store.get_events(
|
|
||||||
list(state_map.values()),
|
|
||||||
get_prev_content=False,
|
|
||||||
redact_behaviour=EventRedactBehaviour.as_is,
|
|
||||||
)
|
|
||||||
event_map.update(evs)
|
|
||||||
|
|
||||||
state = [event_map[e] for e in state_map.values()]
|
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"Error attempting to resolve state at missing prev_events",
|
"Error attempting to resolve state at missing prev_events",
|
||||||
|
@ -894,14 +880,14 @@ class FederationEventHandler:
|
||||||
"We can't get valid state history.",
|
"We can't get valid state history.",
|
||||||
affected=event_id,
|
affected=event_id,
|
||||||
)
|
)
|
||||||
return state
|
return state_map
|
||||||
|
|
||||||
async def _get_state_after_missing_prev_event(
|
async def _get_state_ids_after_missing_prev_event(
|
||||||
self,
|
self,
|
||||||
destination: str,
|
destination: str,
|
||||||
room_id: str,
|
room_id: str,
|
||||||
event_id: str,
|
event_id: str,
|
||||||
) -> List[EventBase]:
|
) -> StateMap[str]:
|
||||||
"""Requests all of the room state at a given event from a remote homeserver.
|
"""Requests all of the room state at a given event from a remote homeserver.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@ -910,7 +896,7 @@ class FederationEventHandler:
|
||||||
event_id: The id of the event we want the state at.
|
event_id: The id of the event we want the state at.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A list of events in the state, including the event itself
|
The event ids of the state *after* the given event.
|
||||||
"""
|
"""
|
||||||
(
|
(
|
||||||
state_event_ids,
|
state_event_ids,
|
||||||
|
@ -925,19 +911,17 @@ class FederationEventHandler:
|
||||||
len(auth_event_ids),
|
len(auth_event_ids),
|
||||||
)
|
)
|
||||||
|
|
||||||
# start by just trying to fetch the events from the store
|
# Start by checking events we already have in the DB
|
||||||
desired_events = set(state_event_ids)
|
desired_events = set(state_event_ids)
|
||||||
desired_events.add(event_id)
|
desired_events.add(event_id)
|
||||||
logger.debug("Fetching %i events from cache/store", len(desired_events))
|
logger.debug("Fetching %i events from cache/store", len(desired_events))
|
||||||
fetched_events = await self._store.get_events(
|
have_events = await self._store.have_seen_events(room_id, desired_events)
|
||||||
desired_events, allow_rejected=True
|
|
||||||
)
|
|
||||||
|
|
||||||
missing_desired_events = desired_events - fetched_events.keys()
|
missing_desired_events = desired_events - have_events
|
||||||
logger.debug(
|
logger.debug(
|
||||||
"We are missing %i events (got %i)",
|
"We are missing %i events (got %i)",
|
||||||
len(missing_desired_events),
|
len(missing_desired_events),
|
||||||
len(fetched_events),
|
len(have_events),
|
||||||
)
|
)
|
||||||
|
|
||||||
# We probably won't need most of the auth events, so let's just check which
|
# We probably won't need most of the auth events, so let's just check which
|
||||||
|
@ -948,7 +932,7 @@ class FederationEventHandler:
|
||||||
# already have a bunch of the state events. It would be nice if the
|
# already have a bunch of the state events. It would be nice if the
|
||||||
# federation api gave us a way of finding out which we actually need.
|
# federation api gave us a way of finding out which we actually need.
|
||||||
|
|
||||||
missing_auth_events = set(auth_event_ids) - fetched_events.keys()
|
missing_auth_events = set(auth_event_ids) - have_events
|
||||||
missing_auth_events.difference_update(
|
missing_auth_events.difference_update(
|
||||||
await self._store.have_seen_events(room_id, missing_auth_events)
|
await self._store.have_seen_events(room_id, missing_auth_events)
|
||||||
)
|
)
|
||||||
|
@ -974,47 +958,51 @@ class FederationEventHandler:
|
||||||
destination=destination, room_id=room_id, event_ids=missing_events
|
destination=destination, room_id=room_id, event_ids=missing_events
|
||||||
)
|
)
|
||||||
|
|
||||||
# we need to make sure we re-load from the database to get the rejected
|
# We now need to fill out the state map, which involves fetching the
|
||||||
# state correct.
|
# type and state key for each event ID in the state.
|
||||||
fetched_events.update(
|
state_map = {}
|
||||||
await self._store.get_events(missing_desired_events, allow_rejected=True)
|
|
||||||
)
|
|
||||||
|
|
||||||
# check for events which were in the wrong room.
|
event_metadata = await self._store.get_metadata_for_events(state_event_ids)
|
||||||
#
|
for state_event_id, metadata in event_metadata.items():
|
||||||
# this can happen if a remote server claims that the state or
|
if metadata.room_id != room_id:
|
||||||
# auth_events at an event in room A are actually events in room B
|
# This is a bogus situation, but since we may only discover it a long time
|
||||||
|
# after it happened, we try our best to carry on, by just omitting the
|
||||||
|
# bad events from the returned state set.
|
||||||
|
#
|
||||||
|
# This can happen if a remote server claims that the state or
|
||||||
|
# auth_events at an event in room A are actually events in room B
|
||||||
|
logger.warning(
|
||||||
|
"Remote server %s claims event %s in room %s is an auth/state "
|
||||||
|
"event in room %s",
|
||||||
|
destination,
|
||||||
|
state_event_id,
|
||||||
|
metadata.room_id,
|
||||||
|
room_id,
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
|
||||||
bad_events = [
|
if metadata.state_key is None:
|
||||||
(event_id, event.room_id)
|
logger.warning(
|
||||||
for event_id, event in fetched_events.items()
|
"Remote server gave us non-state event in state: %s", state_event_id
|
||||||
if event.room_id != room_id
|
)
|
||||||
]
|
continue
|
||||||
|
|
||||||
for bad_event_id, bad_room_id in bad_events:
|
state_map[(metadata.event_type, metadata.state_key)] = state_event_id
|
||||||
# This is a bogus situation, but since we may only discover it a long time
|
|
||||||
# after it happened, we try our best to carry on, by just omitting the
|
|
||||||
# bad events from the returned state set.
|
|
||||||
logger.warning(
|
|
||||||
"Remote server %s claims event %s in room %s is an auth/state "
|
|
||||||
"event in room %s",
|
|
||||||
destination,
|
|
||||||
bad_event_id,
|
|
||||||
bad_room_id,
|
|
||||||
room_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
del fetched_events[bad_event_id]
|
|
||||||
|
|
||||||
# if we couldn't get the prev event in question, that's a problem.
|
# if we couldn't get the prev event in question, that's a problem.
|
||||||
remote_event = fetched_events.get(event_id)
|
remote_event = await self._store.get_event(
|
||||||
|
event_id,
|
||||||
|
allow_none=True,
|
||||||
|
allow_rejected=True,
|
||||||
|
redact_behaviour=EventRedactBehaviour.as_is,
|
||||||
|
)
|
||||||
if not remote_event:
|
if not remote_event:
|
||||||
raise Exception("Unable to get missing prev_event %s" % (event_id,))
|
raise Exception("Unable to get missing prev_event %s" % (event_id,))
|
||||||
|
|
||||||
# missing state at that event is a warning, not a blocker
|
# missing state at that event is a warning, not a blocker
|
||||||
# XXX: this doesn't sound right? it means that we'll end up with incomplete
|
# XXX: this doesn't sound right? it means that we'll end up with incomplete
|
||||||
# state.
|
# state.
|
||||||
failed_to_fetch = desired_events - fetched_events.keys()
|
failed_to_fetch = desired_events - event_metadata.keys()
|
||||||
if failed_to_fetch:
|
if failed_to_fetch:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"Failed to fetch missing state events for %s %s",
|
"Failed to fetch missing state events for %s %s",
|
||||||
|
@ -1022,14 +1010,12 @@ class FederationEventHandler:
|
||||||
failed_to_fetch,
|
failed_to_fetch,
|
||||||
)
|
)
|
||||||
|
|
||||||
remote_state = [
|
|
||||||
fetched_events[e_id] for e_id in state_event_ids if e_id in fetched_events
|
|
||||||
]
|
|
||||||
|
|
||||||
if remote_event.is_state() and remote_event.rejected_reason is None:
|
if remote_event.is_state() and remote_event.rejected_reason is None:
|
||||||
remote_state.append(remote_event)
|
state_map[
|
||||||
|
(remote_event.type, remote_event.state_key)
|
||||||
|
] = remote_event.event_id
|
||||||
|
|
||||||
return remote_state
|
return state_map
|
||||||
|
|
||||||
async def _get_state_and_persist(
|
async def _get_state_and_persist(
|
||||||
self, destination: str, room_id: str, event_id: str
|
self, destination: str, room_id: str, event_id: str
|
||||||
|
@ -1056,7 +1042,7 @@ class FederationEventHandler:
|
||||||
self,
|
self,
|
||||||
origin: str,
|
origin: str,
|
||||||
event: EventBase,
|
event: EventBase,
|
||||||
state: Optional[Iterable[EventBase]],
|
state_ids: Optional[StateMap[str]],
|
||||||
backfilled: bool = False,
|
backfilled: bool = False,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Called when we have a new non-outlier event.
|
"""Called when we have a new non-outlier event.
|
||||||
|
@ -1078,7 +1064,7 @@ class FederationEventHandler:
|
||||||
|
|
||||||
event: event to be persisted
|
event: event to be persisted
|
||||||
|
|
||||||
state: Normally None, but if we are handling a gap in the graph
|
state_ids: Normally None, but if we are handling a gap in the graph
|
||||||
(ie, we are missing one or more prev_events), the resolved state at the
|
(ie, we are missing one or more prev_events), the resolved state at the
|
||||||
event
|
event
|
||||||
|
|
||||||
|
@ -1090,7 +1076,8 @@ class FederationEventHandler:
|
||||||
|
|
||||||
try:
|
try:
|
||||||
context = await self._state_handler.compute_event_context(
|
context = await self._state_handler.compute_event_context(
|
||||||
event, old_state=state
|
event,
|
||||||
|
state_ids_before_event=state_ids,
|
||||||
)
|
)
|
||||||
context = await self._check_event_auth(
|
context = await self._check_event_auth(
|
||||||
origin,
|
origin,
|
||||||
|
@ -1107,7 +1094,7 @@ class FederationEventHandler:
|
||||||
# For new (non-backfilled and non-outlier) events we check if the event
|
# For new (non-backfilled and non-outlier) events we check if the event
|
||||||
# passes auth based on the current state. If it doesn't then we
|
# passes auth based on the current state. If it doesn't then we
|
||||||
# "soft-fail" the event.
|
# "soft-fail" the event.
|
||||||
await self._check_for_soft_fail(event, state, origin=origin)
|
await self._check_for_soft_fail(event, state_ids, origin=origin)
|
||||||
|
|
||||||
await self._run_push_actions_and_persist_event(event, context, backfilled)
|
await self._run_push_actions_and_persist_event(event, context, backfilled)
|
||||||
|
|
||||||
|
@ -1589,7 +1576,7 @@ class FederationEventHandler:
|
||||||
async def _check_for_soft_fail(
|
async def _check_for_soft_fail(
|
||||||
self,
|
self,
|
||||||
event: EventBase,
|
event: EventBase,
|
||||||
state: Optional[Iterable[EventBase]],
|
state_ids: Optional[StateMap[str]],
|
||||||
origin: str,
|
origin: str,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Checks if we should soft fail the event; if so, marks the event as
|
"""Checks if we should soft fail the event; if so, marks the event as
|
||||||
|
@ -1597,7 +1584,7 @@ class FederationEventHandler:
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
event
|
event
|
||||||
state: The state at the event if we don't have all the event's prev events
|
state_ids: The state at the event if we don't have all the event's prev events
|
||||||
origin: The host the event originates from.
|
origin: The host the event originates from.
|
||||||
"""
|
"""
|
||||||
extrem_ids_list = await self._store.get_latest_event_ids_in_room(event.room_id)
|
extrem_ids_list = await self._store.get_latest_event_ids_in_room(event.room_id)
|
||||||
|
@ -1613,7 +1600,7 @@ class FederationEventHandler:
|
||||||
room_version_obj = KNOWN_ROOM_VERSIONS[room_version]
|
room_version_obj = KNOWN_ROOM_VERSIONS[room_version]
|
||||||
|
|
||||||
# Calculate the "current state".
|
# Calculate the "current state".
|
||||||
if state is not None:
|
if state_ids is not None:
|
||||||
# If we're explicitly given the state then we won't have all the
|
# If we're explicitly given the state then we won't have all the
|
||||||
# prev events, and so we have a gap in the graph. In this case
|
# prev events, and so we have a gap in the graph. In this case
|
||||||
# we want to be a little careful as we might have been down for
|
# we want to be a little careful as we might have been down for
|
||||||
|
@ -1626,17 +1613,20 @@ class FederationEventHandler:
|
||||||
# given state at the event. This should correctly handle cases
|
# given state at the event. This should correctly handle cases
|
||||||
# like bans, especially with state res v2.
|
# like bans, especially with state res v2.
|
||||||
|
|
||||||
state_sets_d = await self._state_storage.get_state_groups(
|
state_sets_d = await self._state_storage.get_state_groups_ids(
|
||||||
event.room_id, extrem_ids
|
event.room_id, extrem_ids
|
||||||
)
|
)
|
||||||
state_sets: List[Iterable[EventBase]] = list(state_sets_d.values())
|
state_sets: List[StateMap[str]] = list(state_sets_d.values())
|
||||||
state_sets.append(state)
|
state_sets.append(state_ids)
|
||||||
current_states = await self._state_handler.resolve_events(
|
current_state_ids = (
|
||||||
room_version, state_sets, event
|
await self._state_resolution_handler.resolve_events_with_store(
|
||||||
|
event.room_id,
|
||||||
|
room_version,
|
||||||
|
state_sets,
|
||||||
|
event_map=None,
|
||||||
|
state_res_store=StateResolutionStore(self._store),
|
||||||
|
)
|
||||||
)
|
)
|
||||||
current_state_ids: StateMap[str] = {
|
|
||||||
k: e.event_id for k, e in current_states.items()
|
|
||||||
}
|
|
||||||
else:
|
else:
|
||||||
current_state_ids = await self._state_handler.get_current_state_ids(
|
current_state_ids = await self._state_handler.get_current_state_ids(
|
||||||
event.room_id, latest_event_ids=extrem_ids
|
event.room_id, latest_event_ids=extrem_ids
|
||||||
|
|
|
@ -55,7 +55,14 @@ from synapse.metrics.background_process_metrics import run_as_background_process
|
||||||
from synapse.replication.http.send_event import ReplicationSendEventRestServlet
|
from synapse.replication.http.send_event import ReplicationSendEventRestServlet
|
||||||
from synapse.storage.databases.main.events_worker import EventRedactBehaviour
|
from synapse.storage.databases.main.events_worker import EventRedactBehaviour
|
||||||
from synapse.storage.state import StateFilter
|
from synapse.storage.state import StateFilter
|
||||||
from synapse.types import Requester, RoomAlias, StreamToken, UserID, create_requester
|
from synapse.types import (
|
||||||
|
MutableStateMap,
|
||||||
|
Requester,
|
||||||
|
RoomAlias,
|
||||||
|
StreamToken,
|
||||||
|
UserID,
|
||||||
|
create_requester,
|
||||||
|
)
|
||||||
from synapse.util import json_decoder, json_encoder, log_failure, unwrapFirstError
|
from synapse.util import json_decoder, json_encoder, log_failure, unwrapFirstError
|
||||||
from synapse.util.async_helpers import Linearizer, gather_results
|
from synapse.util.async_helpers import Linearizer, gather_results
|
||||||
from synapse.util.caches.expiringcache import ExpiringCache
|
from synapse.util.caches.expiringcache import ExpiringCache
|
||||||
|
@ -1022,8 +1029,35 @@ class EventCreationHandler:
|
||||||
#
|
#
|
||||||
# TODO(faster_joins): figure out how this works, and make sure that the
|
# TODO(faster_joins): figure out how this works, and make sure that the
|
||||||
# old state is complete.
|
# old state is complete.
|
||||||
old_state = await self.store.get_events_as_list(state_event_ids)
|
metadata = await self.store.get_metadata_for_events(state_event_ids)
|
||||||
context = await self.state.compute_event_context(event, old_state=old_state)
|
|
||||||
|
state_map_for_event: MutableStateMap[str] = {}
|
||||||
|
for state_id in state_event_ids:
|
||||||
|
data = metadata.get(state_id)
|
||||||
|
if data is None:
|
||||||
|
# We're trying to persist a new historical batch of events
|
||||||
|
# with the given state, e.g. via
|
||||||
|
# `RoomBatchSendEventRestServlet`. The state can be inferred
|
||||||
|
# by Synapse or set directly by the client.
|
||||||
|
#
|
||||||
|
# Either way, we should have persisted all the state before
|
||||||
|
# getting here.
|
||||||
|
raise Exception(
|
||||||
|
f"State event {state_id} not found in DB,"
|
||||||
|
" Synapse should have persisted it before using it."
|
||||||
|
)
|
||||||
|
|
||||||
|
if data.state_key is None:
|
||||||
|
raise Exception(
|
||||||
|
f"Trying to set non-state event {state_id} as state"
|
||||||
|
)
|
||||||
|
|
||||||
|
state_map_for_event[(data.event_type, data.state_key)] = state_id
|
||||||
|
|
||||||
|
context = await self.state.compute_event_context(
|
||||||
|
event,
|
||||||
|
state_ids_before_event=state_map_for_event,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
context = await self.state.compute_event_context(event)
|
context = await self.state.compute_event_context(event)
|
||||||
|
|
||||||
|
|
|
@ -261,7 +261,7 @@ class StateHandler:
|
||||||
async def compute_event_context(
|
async def compute_event_context(
|
||||||
self,
|
self,
|
||||||
event: EventBase,
|
event: EventBase,
|
||||||
old_state: Optional[Iterable[EventBase]] = None,
|
state_ids_before_event: Optional[StateMap[str]] = None,
|
||||||
partial_state: bool = False,
|
partial_state: bool = False,
|
||||||
) -> EventContext:
|
) -> EventContext:
|
||||||
"""Build an EventContext structure for a non-outlier event.
|
"""Build an EventContext structure for a non-outlier event.
|
||||||
|
@ -273,12 +273,12 @@ class StateHandler:
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
event:
|
event:
|
||||||
old_state: The state at the event if it can't be
|
state_ids_before_event: The event ids of the state before the event if
|
||||||
calculated from existing events. This is normally only specified
|
it can't be calculated from existing events. This is normally
|
||||||
when receiving an event from federation where we don't have the
|
only specified when receiving an event from federation where we
|
||||||
prev events for, e.g. when backfilling.
|
don't have the prev events, e.g. when backfilling.
|
||||||
partial_state: True if `old_state` is partial and omits non-critical
|
partial_state: True if `state_ids_before_event` is partial and omits
|
||||||
membership events
|
non-critical membership events
|
||||||
Returns:
|
Returns:
|
||||||
The event context.
|
The event context.
|
||||||
"""
|
"""
|
||||||
|
@ -286,13 +286,11 @@ class StateHandler:
|
||||||
assert not event.internal_metadata.is_outlier()
|
assert not event.internal_metadata.is_outlier()
|
||||||
|
|
||||||
#
|
#
|
||||||
# first of all, figure out the state before the event
|
# first of all, figure out the state before the event, unless we
|
||||||
|
# already have it.
|
||||||
#
|
#
|
||||||
if old_state:
|
if state_ids_before_event:
|
||||||
# if we're given the state before the event, then we use that
|
# if we're given the state before the event, then we use that
|
||||||
state_ids_before_event: StateMap[str] = {
|
|
||||||
(s.type, s.state_key): s.event_id for s in old_state
|
|
||||||
}
|
|
||||||
state_group_before_event = None
|
state_group_before_event = None
|
||||||
state_group_before_event_prev_group = None
|
state_group_before_event_prev_group = None
|
||||||
deltas_to_state_group_before_event = None
|
deltas_to_state_group_before_event = None
|
||||||
|
|
|
@ -16,6 +16,8 @@ import collections.abc
|
||||||
import logging
|
import logging
|
||||||
from typing import TYPE_CHECKING, Collection, Dict, Iterable, Optional, Set, Tuple
|
from typing import TYPE_CHECKING, Collection, Dict, Iterable, Optional, Set, Tuple
|
||||||
|
|
||||||
|
import attr
|
||||||
|
|
||||||
from synapse.api.constants import EventTypes, Membership
|
from synapse.api.constants import EventTypes, Membership
|
||||||
from synapse.api.errors import NotFoundError, UnsupportedRoomVersionError
|
from synapse.api.errors import NotFoundError, UnsupportedRoomVersionError
|
||||||
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion
|
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion
|
||||||
|
@ -26,6 +28,7 @@ from synapse.storage.database import (
|
||||||
DatabasePool,
|
DatabasePool,
|
||||||
LoggingDatabaseConnection,
|
LoggingDatabaseConnection,
|
||||||
LoggingTransaction,
|
LoggingTransaction,
|
||||||
|
make_in_list_sql_clause,
|
||||||
)
|
)
|
||||||
from synapse.storage.databases.main.events_worker import EventsWorkerStore
|
from synapse.storage.databases.main.events_worker import EventsWorkerStore
|
||||||
from synapse.storage.databases.main.roommember import RoomMemberWorkerStore
|
from synapse.storage.databases.main.roommember import RoomMemberWorkerStore
|
||||||
|
@ -33,6 +36,7 @@ from synapse.storage.state import StateFilter
|
||||||
from synapse.types import JsonDict, JsonMapping, StateMap
|
from synapse.types import JsonDict, JsonMapping, StateMap
|
||||||
from synapse.util.caches import intern_string
|
from synapse.util.caches import intern_string
|
||||||
from synapse.util.caches.descriptors import cached, cachedList
|
from synapse.util.caches.descriptors import cached, cachedList
|
||||||
|
from synapse.util.iterutils import batch_iter
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from synapse.server import HomeServer
|
from synapse.server import HomeServer
|
||||||
|
@ -43,6 +47,15 @@ logger = logging.getLogger(__name__)
|
||||||
MAX_STATE_DELTA_HOPS = 100
|
MAX_STATE_DELTA_HOPS = 100
|
||||||
|
|
||||||
|
|
||||||
|
@attr.s(slots=True, frozen=True, auto_attribs=True)
|
||||||
|
class EventMetadata:
|
||||||
|
"""Returned by `get_metadata_for_events`"""
|
||||||
|
|
||||||
|
room_id: str
|
||||||
|
event_type: str
|
||||||
|
state_key: Optional[str]
|
||||||
|
|
||||||
|
|
||||||
def _retrieve_and_check_room_version(room_id: str, room_version_id: str) -> RoomVersion:
|
def _retrieve_and_check_room_version(room_id: str, room_version_id: str) -> RoomVersion:
|
||||||
v = KNOWN_ROOM_VERSIONS.get(room_version_id)
|
v = KNOWN_ROOM_VERSIONS.get(room_version_id)
|
||||||
if not v:
|
if not v:
|
||||||
|
@ -133,6 +146,52 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
|
||||||
|
|
||||||
return room_version
|
return room_version
|
||||||
|
|
||||||
|
async def get_metadata_for_events(
|
||||||
|
self, event_ids: Collection[str]
|
||||||
|
) -> Dict[str, EventMetadata]:
|
||||||
|
"""Get some metadata (room_id, type, state_key) for the given events.
|
||||||
|
|
||||||
|
This method is a faster alternative than fetching the full events from
|
||||||
|
the DB, and should be used when the full event is not needed.
|
||||||
|
|
||||||
|
Returns metadata for rejected and redacted events. Events that have not
|
||||||
|
been persisted are omitted from the returned dict.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def get_metadata_for_events_txn(
|
||||||
|
txn: LoggingTransaction,
|
||||||
|
batch_ids: Collection[str],
|
||||||
|
) -> Dict[str, EventMetadata]:
|
||||||
|
clause, args = make_in_list_sql_clause(
|
||||||
|
self.database_engine, "e.event_id", batch_ids
|
||||||
|
)
|
||||||
|
|
||||||
|
sql = f"""
|
||||||
|
SELECT e.event_id, e.room_id, e.type, e.state_key FROM events AS e
|
||||||
|
LEFT JOIN state_events USING (event_id)
|
||||||
|
WHERE {clause}
|
||||||
|
"""
|
||||||
|
|
||||||
|
txn.execute(sql, args)
|
||||||
|
return {
|
||||||
|
event_id: EventMetadata(
|
||||||
|
room_id=room_id, event_type=event_type, state_key=state_key
|
||||||
|
)
|
||||||
|
for event_id, room_id, event_type, state_key in txn
|
||||||
|
}
|
||||||
|
|
||||||
|
result_map: Dict[str, EventMetadata] = {}
|
||||||
|
for batch_ids in batch_iter(event_ids, 1000):
|
||||||
|
result_map.update(
|
||||||
|
await self.db_pool.runInteraction(
|
||||||
|
"get_metadata_for_events",
|
||||||
|
get_metadata_for_events_txn,
|
||||||
|
batch_ids=batch_ids,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
return result_map
|
||||||
|
|
||||||
async def get_room_predecessor(self, room_id: str) -> Optional[JsonMapping]:
|
async def get_room_predecessor(self, room_id: str) -> Optional[JsonMapping]:
|
||||||
"""Get the predecessor of an upgraded room if it exists.
|
"""Get the predecessor of an upgraded room if it exists.
|
||||||
Otherwise return None.
|
Otherwise return None.
|
||||||
|
|
|
@ -276,7 +276,11 @@ class FederationTestCase(unittest.FederatingHomeserverTestCase):
|
||||||
# federation handler wanting to backfill the fake event.
|
# federation handler wanting to backfill the fake event.
|
||||||
self.get_success(
|
self.get_success(
|
||||||
federation_event_handler._process_received_pdu(
|
federation_event_handler._process_received_pdu(
|
||||||
self.OTHER_SERVER_NAME, event, state=current_state
|
self.OTHER_SERVER_NAME,
|
||||||
|
event,
|
||||||
|
state_ids={
|
||||||
|
(e.type, e.state_key): e.event_id for e in current_state
|
||||||
|
},
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -69,7 +69,7 @@ class ExtremPruneTestCase(HomeserverTestCase):
|
||||||
def persist_event(self, event, state=None):
|
def persist_event(self, event, state=None):
|
||||||
"""Persist the event, with optional state"""
|
"""Persist the event, with optional state"""
|
||||||
context = self.get_success(
|
context = self.get_success(
|
||||||
self.state.compute_event_context(event, old_state=state)
|
self.state.compute_event_context(event, state_ids_before_event=state)
|
||||||
)
|
)
|
||||||
self.get_success(self.persistence.persist_event(event, context))
|
self.get_success(self.persistence.persist_event(event, context))
|
||||||
|
|
||||||
|
@ -103,9 +103,11 @@ class ExtremPruneTestCase(HomeserverTestCase):
|
||||||
RoomVersions.V6,
|
RoomVersions.V6,
|
||||||
)
|
)
|
||||||
|
|
||||||
state_before_gap = self.get_success(self.state.get_current_state(self.room_id))
|
state_before_gap = self.get_success(
|
||||||
|
self.state.get_current_state_ids(self.room_id)
|
||||||
|
)
|
||||||
|
|
||||||
self.persist_event(remote_event_2, state=state_before_gap.values())
|
self.persist_event(remote_event_2, state=state_before_gap)
|
||||||
|
|
||||||
# Check the new extremity is just the new remote event.
|
# Check the new extremity is just the new remote event.
|
||||||
self.assert_extremities([remote_event_2.event_id])
|
self.assert_extremities([remote_event_2.event_id])
|
||||||
|
@ -135,13 +137,14 @@ class ExtremPruneTestCase(HomeserverTestCase):
|
||||||
# setting. The state resolution across the old and new event will then
|
# setting. The state resolution across the old and new event will then
|
||||||
# include it, and so the resolved state won't match the new state.
|
# include it, and so the resolved state won't match the new state.
|
||||||
state_before_gap = dict(
|
state_before_gap = dict(
|
||||||
self.get_success(self.state.get_current_state(self.room_id))
|
self.get_success(self.state.get_current_state_ids(self.room_id))
|
||||||
)
|
)
|
||||||
state_before_gap.pop(("m.room.history_visibility", ""))
|
state_before_gap.pop(("m.room.history_visibility", ""))
|
||||||
|
|
||||||
context = self.get_success(
|
context = self.get_success(
|
||||||
self.state.compute_event_context(
|
self.state.compute_event_context(
|
||||||
remote_event_2, old_state=state_before_gap.values()
|
remote_event_2,
|
||||||
|
state_ids_before_event=state_before_gap,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -177,9 +180,11 @@ class ExtremPruneTestCase(HomeserverTestCase):
|
||||||
RoomVersions.V6,
|
RoomVersions.V6,
|
||||||
)
|
)
|
||||||
|
|
||||||
state_before_gap = self.get_success(self.state.get_current_state(self.room_id))
|
state_before_gap = self.get_success(
|
||||||
|
self.state.get_current_state_ids(self.room_id)
|
||||||
|
)
|
||||||
|
|
||||||
self.persist_event(remote_event_2, state=state_before_gap.values())
|
self.persist_event(remote_event_2, state=state_before_gap)
|
||||||
|
|
||||||
# Check the new extremity is just the new remote event.
|
# Check the new extremity is just the new remote event.
|
||||||
self.assert_extremities([remote_event_2.event_id])
|
self.assert_extremities([remote_event_2.event_id])
|
||||||
|
@ -207,9 +212,11 @@ class ExtremPruneTestCase(HomeserverTestCase):
|
||||||
RoomVersions.V6,
|
RoomVersions.V6,
|
||||||
)
|
)
|
||||||
|
|
||||||
state_before_gap = self.get_success(self.state.get_current_state(self.room_id))
|
state_before_gap = self.get_success(
|
||||||
|
self.state.get_current_state_ids(self.room_id)
|
||||||
|
)
|
||||||
|
|
||||||
self.persist_event(remote_event_2, state=state_before_gap.values())
|
self.persist_event(remote_event_2, state=state_before_gap)
|
||||||
|
|
||||||
# Check the new extremity is just the new remote event.
|
# Check the new extremity is just the new remote event.
|
||||||
self.assert_extremities([self.remote_event_1.event_id, remote_event_2.event_id])
|
self.assert_extremities([self.remote_event_1.event_id, remote_event_2.event_id])
|
||||||
|
@ -247,9 +254,11 @@ class ExtremPruneTestCase(HomeserverTestCase):
|
||||||
RoomVersions.V6,
|
RoomVersions.V6,
|
||||||
)
|
)
|
||||||
|
|
||||||
state_before_gap = self.get_success(self.state.get_current_state(self.room_id))
|
state_before_gap = self.get_success(
|
||||||
|
self.state.get_current_state_ids(self.room_id)
|
||||||
|
)
|
||||||
|
|
||||||
self.persist_event(remote_event_2, state=state_before_gap.values())
|
self.persist_event(remote_event_2, state=state_before_gap)
|
||||||
|
|
||||||
# Check the new extremity is just the new remote event.
|
# Check the new extremity is just the new remote event.
|
||||||
self.assert_extremities([remote_event_2.event_id])
|
self.assert_extremities([remote_event_2.event_id])
|
||||||
|
@ -289,9 +298,11 @@ class ExtremPruneTestCase(HomeserverTestCase):
|
||||||
RoomVersions.V6,
|
RoomVersions.V6,
|
||||||
)
|
)
|
||||||
|
|
||||||
state_before_gap = self.get_success(self.state.get_current_state(self.room_id))
|
state_before_gap = self.get_success(
|
||||||
|
self.state.get_current_state_ids(self.room_id)
|
||||||
|
)
|
||||||
|
|
||||||
self.persist_event(remote_event_2, state=state_before_gap.values())
|
self.persist_event(remote_event_2, state=state_before_gap)
|
||||||
|
|
||||||
# Check the new extremity is just the new remote event.
|
# Check the new extremity is just the new remote event.
|
||||||
self.assert_extremities([remote_event_2.event_id, local_message_event_id])
|
self.assert_extremities([remote_event_2.event_id, local_message_event_id])
|
||||||
|
@ -323,9 +334,11 @@ class ExtremPruneTestCase(HomeserverTestCase):
|
||||||
RoomVersions.V6,
|
RoomVersions.V6,
|
||||||
)
|
)
|
||||||
|
|
||||||
state_before_gap = self.get_success(self.state.get_current_state(self.room_id))
|
state_before_gap = self.get_success(
|
||||||
|
self.state.get_current_state_ids(self.room_id)
|
||||||
|
)
|
||||||
|
|
||||||
self.persist_event(remote_event_2, state=state_before_gap.values())
|
self.persist_event(remote_event_2, state=state_before_gap)
|
||||||
|
|
||||||
# Check the new extremity is just the new remote event.
|
# Check the new extremity is just the new remote event.
|
||||||
self.assert_extremities([local_message_event_id, remote_event_2.event_id])
|
self.assert_extremities([local_message_event_id, remote_event_2.event_id])
|
||||||
|
|
|
@ -442,7 +442,12 @@ class StateTestCase(unittest.TestCase):
|
||||||
]
|
]
|
||||||
|
|
||||||
context = yield defer.ensureDeferred(
|
context = yield defer.ensureDeferred(
|
||||||
self.state.compute_event_context(event, old_state=old_state)
|
self.state.compute_event_context(
|
||||||
|
event,
|
||||||
|
state_ids_before_event={
|
||||||
|
(e.type, e.state_key): e.event_id for e in old_state
|
||||||
|
},
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
prev_state_ids = yield defer.ensureDeferred(context.get_prev_state_ids())
|
prev_state_ids = yield defer.ensureDeferred(context.get_prev_state_ids())
|
||||||
|
@ -467,7 +472,12 @@ class StateTestCase(unittest.TestCase):
|
||||||
]
|
]
|
||||||
|
|
||||||
context = yield defer.ensureDeferred(
|
context = yield defer.ensureDeferred(
|
||||||
self.state.compute_event_context(event, old_state=old_state)
|
self.state.compute_event_context(
|
||||||
|
event,
|
||||||
|
state_ids_before_event={
|
||||||
|
(e.type, e.state_key): e.event_id for e in old_state
|
||||||
|
},
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
prev_state_ids = yield defer.ensureDeferred(context.get_prev_state_ids())
|
prev_state_ids = yield defer.ensureDeferred(context.get_prev_state_ids())
|
||||||
|
|
Loading…
Reference in New Issue