Split `on_receive_pdu` in half (#10640)
Here we split on_receive_pdu into two functions (on_receive_pdu and process_pulled_event), rather than having both cases in the same method. There's a tiny bit of overlap, but not that much.pull/10664/head
parent
50af1efe4b
commit
e81d62009e
|
@ -0,0 +1 @@
|
|||
Clean up some of the federation event authentication code for clarity.
|
|
@ -1005,9 +1005,7 @@ class FederationServer(FederationBase):
|
|||
async with lock:
|
||||
logger.info("handling received PDU: %s", event)
|
||||
try:
|
||||
await self.handler.on_receive_pdu(
|
||||
origin, event, sent_to_us_directly=True
|
||||
)
|
||||
await self.handler.on_receive_pdu(origin, event)
|
||||
except FederationError as e:
|
||||
# XXX: Ideally we'd inform the remote we failed to process
|
||||
# the event, but we can't return an error in the transaction
|
||||
|
|
|
@ -203,18 +203,13 @@ class FederationHandler(BaseHandler):
|
|||
|
||||
self._ephemeral_messages_enabled = hs.config.enable_ephemeral_messages
|
||||
|
||||
async def on_receive_pdu(
|
||||
self, origin: str, pdu: EventBase, sent_to_us_directly: bool = False
|
||||
) -> None:
|
||||
"""Process a PDU received via a federation /send/ transaction, or
|
||||
via backfill of missing prev_events
|
||||
async def on_receive_pdu(self, origin: str, pdu: EventBase) -> None:
|
||||
"""Process a PDU received via a federation /send/ transaction
|
||||
|
||||
Args:
|
||||
origin: server which initiated the /send/ transaction. Will
|
||||
be used to fetch missing events or state.
|
||||
pdu: received PDU
|
||||
sent_to_us_directly: True if this event was pushed to us; False if
|
||||
we pulled it as the result of a missing prev_event.
|
||||
"""
|
||||
|
||||
room_id = pdu.room_id
|
||||
|
@ -276,8 +271,6 @@ class FederationHandler(BaseHandler):
|
|||
)
|
||||
return None
|
||||
|
||||
state = None
|
||||
|
||||
# Check that the event passes auth based on the state at the event. This is
|
||||
# done for events that are to be added to the timeline (non-outliers).
|
||||
#
|
||||
|
@ -285,7 +278,6 @@ class FederationHandler(BaseHandler):
|
|||
# - Fetching any missing prev events to fill in gaps in the graph
|
||||
# - Fetching state if we have a hole in the graph
|
||||
if not pdu.internal_metadata.is_outlier():
|
||||
if sent_to_us_directly:
|
||||
prevs = set(pdu.prev_event_ids())
|
||||
seen = await self.store.have_events_in_timeline(prevs)
|
||||
missing_prevs = prevs - seen
|
||||
|
@ -351,17 +343,7 @@ class FederationHandler(BaseHandler):
|
|||
affected=pdu.event_id,
|
||||
)
|
||||
|
||||
else:
|
||||
state = await self._resolve_state_at_missing_prevs(origin, pdu)
|
||||
|
||||
# A second round of checks for all events. Check that the event passes auth
|
||||
# based on `auth_events`, this allows us to assert that the event would
|
||||
# have been allowed at some point. If an event passes this check its OK
|
||||
# for it to be used as part of a returned `/state` request, as either
|
||||
# a) we received the event as part of the original join and so trust it, or
|
||||
# b) we'll do a state resolution with existing state before it becomes
|
||||
# part of the "current state", which adds more protection.
|
||||
await self._process_received_pdu(origin, pdu, state=state)
|
||||
await self._process_received_pdu(origin, pdu, state=None)
|
||||
|
||||
async def _get_missing_events_for_pdu(
|
||||
self, origin: str, pdu: EventBase, prevs: Set[str], min_depth: int
|
||||
|
@ -461,24 +443,7 @@ class FederationHandler(BaseHandler):
|
|||
return
|
||||
|
||||
logger.info("Got %d prev_events", len(missing_events))
|
||||
|
||||
# We want to sort these by depth so we process them and
|
||||
# tell clients about them in order.
|
||||
missing_events.sort(key=lambda x: x.depth)
|
||||
|
||||
for ev in missing_events:
|
||||
logger.info("Handling received prev_event %s", ev)
|
||||
with nested_logging_context(ev.event_id):
|
||||
try:
|
||||
await self.on_receive_pdu(origin, ev, sent_to_us_directly=False)
|
||||
except FederationError as e:
|
||||
if e.code == 403:
|
||||
logger.warning(
|
||||
"Received prev_event %s failed history check.",
|
||||
ev.event_id,
|
||||
)
|
||||
else:
|
||||
raise
|
||||
await self._process_pulled_events(origin, missing_events)
|
||||
|
||||
async def _get_state_for_room(
|
||||
self,
|
||||
|
@ -1395,6 +1360,81 @@ class FederationHandler(BaseHandler):
|
|||
event_infos,
|
||||
)
|
||||
|
||||
async def _process_pulled_events(
|
||||
self, origin: str, events: Iterable[EventBase]
|
||||
) -> None:
|
||||
"""Process a batch of events we have pulled from a remote server
|
||||
|
||||
Pulls in any events required to auth the events, persists the received events,
|
||||
and notifies clients, if appropriate.
|
||||
|
||||
Assumes the events have already had their signatures and hashes checked.
|
||||
|
||||
Params:
|
||||
origin: The server we received these events from
|
||||
events: The received events.
|
||||
"""
|
||||
|
||||
# We want to sort these by depth so we process them and
|
||||
# tell clients about them in order.
|
||||
sorted_events = sorted(events, key=lambda x: x.depth)
|
||||
|
||||
for ev in sorted_events:
|
||||
with nested_logging_context(ev.event_id):
|
||||
await self._process_pulled_event(origin, ev)
|
||||
|
||||
async def _process_pulled_event(self, origin: str, event: EventBase) -> None:
|
||||
"""Process a single event that we have pulled from a remote server
|
||||
|
||||
Pulls in any events required to auth the event, persists the received event,
|
||||
and notifies clients, if appropriate.
|
||||
|
||||
Assumes the event has already had its signatures and hashes checked.
|
||||
|
||||
This is somewhat equivalent to on_receive_pdu, but applies somewhat different
|
||||
logic in the case that we are missing prev_events (in particular, it just
|
||||
requests the state at that point, rather than triggering a get_missing_events) -
|
||||
so is appropriate when we have pulled the event from a remote server, rather
|
||||
than having it pushed to us.
|
||||
|
||||
Params:
|
||||
origin: The server we received this event from
|
||||
events: The received event
|
||||
"""
|
||||
logger.info("Processing pulled event %s", event)
|
||||
|
||||
# these should not be outliers.
|
||||
assert not event.internal_metadata.is_outlier()
|
||||
|
||||
event_id = event.event_id
|
||||
|
||||
existing = await self.store.get_event(
|
||||
event_id, allow_none=True, allow_rejected=True
|
||||
)
|
||||
if existing:
|
||||
if not existing.internal_metadata.is_outlier():
|
||||
logger.info(
|
||||
"Ignoring received event %s which we have already seen",
|
||||
event_id,
|
||||
)
|
||||
return
|
||||
logger.info("De-outliering event %s", event_id)
|
||||
|
||||
try:
|
||||
self._sanity_check_event(event)
|
||||
except SynapseError as err:
|
||||
logger.warning("Event %s failed sanity check: %s", event_id, err)
|
||||
return
|
||||
|
||||
try:
|
||||
state = await self._resolve_state_at_missing_prevs(origin, event)
|
||||
await self._process_received_pdu(origin, event, state=state)
|
||||
except FederationError as e:
|
||||
if e.code == 403:
|
||||
logger.warning("Pulled event %s failed history check.", event_id)
|
||||
else:
|
||||
raise
|
||||
|
||||
async def _resolve_state_at_missing_prevs(
|
||||
self, dest: str, event: EventBase
|
||||
) -> Optional[Iterable[EventBase]]:
|
||||
|
@ -1780,7 +1820,7 @@ class FederationHandler(BaseHandler):
|
|||
p,
|
||||
)
|
||||
with nested_logging_context(p.event_id):
|
||||
await self.on_receive_pdu(origin, p, sent_to_us_directly=True)
|
||||
await self.on_receive_pdu(origin, p)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
"Error handling queued PDU %s from %s: %s", p.event_id, origin, e
|
||||
|
|
|
@ -85,11 +85,7 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
|
|||
|
||||
# Send the join, it should return None (which is not an error)
|
||||
self.assertEqual(
|
||||
self.get_success(
|
||||
self.handler.on_receive_pdu(
|
||||
"test.serv", join_event, sent_to_us_directly=True
|
||||
)
|
||||
),
|
||||
self.get_success(self.handler.on_receive_pdu("test.serv", join_event)),
|
||||
None,
|
||||
)
|
||||
|
||||
|
@ -135,9 +131,7 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
|
|||
|
||||
with LoggingContext("test-context"):
|
||||
failure = self.get_failure(
|
||||
self.handler.on_receive_pdu(
|
||||
"test.serv", lying_event, sent_to_us_directly=True
|
||||
),
|
||||
self.handler.on_receive_pdu("test.serv", lying_event),
|
||||
FederationError,
|
||||
)
|
||||
|
||||
|
|
Loading…
Reference in New Issue