Get rid of `_auth_and_persist_event` (#10781)
This is only called in two places, and the code seems much clearer without it.pull/10733/head
parent
03caba6577
commit
abedf7d77f
|
@ -0,0 +1 @@
|
||||||
|
Clean up some of the federation event authentication code for clarity.
|
|
@ -909,12 +909,18 @@ class FederationEventHandler:
|
||||||
context = await self._state_handler.compute_event_context(
|
context = await self._state_handler.compute_event_context(
|
||||||
event, old_state=state
|
event, old_state=state
|
||||||
)
|
)
|
||||||
await self._auth_and_persist_event(
|
context = await self._check_event_auth(
|
||||||
origin, event, context, state=state, backfilled=backfilled
|
origin,
|
||||||
|
event,
|
||||||
|
context,
|
||||||
|
state=state,
|
||||||
|
backfilled=backfilled,
|
||||||
)
|
)
|
||||||
except AuthError as e:
|
except AuthError as e:
|
||||||
raise FederationError("ERROR", e.code, e.msg, affected=event.event_id)
|
raise FederationError("ERROR", e.code, e.msg, affected=event.event_id)
|
||||||
|
|
||||||
|
await self._run_push_actions_and_persist_event(event, context, backfilled)
|
||||||
|
|
||||||
if backfilled:
|
if backfilled:
|
||||||
return
|
return
|
||||||
|
|
||||||
|
@ -1239,51 +1245,6 @@ class FederationEventHandler:
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
async def _auth_and_persist_event(
|
|
||||||
self,
|
|
||||||
origin: str,
|
|
||||||
event: EventBase,
|
|
||||||
context: EventContext,
|
|
||||||
state: Optional[Iterable[EventBase]] = None,
|
|
||||||
claimed_auth_event_map: Optional[StateMap[EventBase]] = None,
|
|
||||||
backfilled: bool = False,
|
|
||||||
) -> None:
|
|
||||||
"""
|
|
||||||
Process an event by performing auth checks and then persisting to the database.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
origin: The host the event originates from.
|
|
||||||
event: The event itself.
|
|
||||||
context:
|
|
||||||
The event context.
|
|
||||||
|
|
||||||
state:
|
|
||||||
The state events used to check the event for soft-fail. If this is
|
|
||||||
not provided the current state events will be used.
|
|
||||||
|
|
||||||
claimed_auth_event_map:
|
|
||||||
A map of (type, state_key) => event for the event's claimed auth_events.
|
|
||||||
Possibly incomplete, and possibly including events that are not yet
|
|
||||||
persisted, or authed, or in the right room.
|
|
||||||
|
|
||||||
Only populated when populating outliers.
|
|
||||||
|
|
||||||
backfilled: True if the event was backfilled.
|
|
||||||
"""
|
|
||||||
# claimed_auth_event_map should be given iff the event is an outlier
|
|
||||||
assert bool(claimed_auth_event_map) == event.internal_metadata.outlier
|
|
||||||
|
|
||||||
context = await self._check_event_auth(
|
|
||||||
origin,
|
|
||||||
event,
|
|
||||||
context,
|
|
||||||
state=state,
|
|
||||||
claimed_auth_event_map=claimed_auth_event_map,
|
|
||||||
backfilled=backfilled,
|
|
||||||
)
|
|
||||||
|
|
||||||
await self._run_push_actions_and_persist_event(event, context, backfilled)
|
|
||||||
|
|
||||||
async def _check_event_auth(
|
async def _check_event_auth(
|
||||||
self,
|
self,
|
||||||
origin: str,
|
origin: str,
|
||||||
|
@ -1558,39 +1519,45 @@ class FederationEventHandler:
|
||||||
event.room_id, [e.event_id for e in remote_auth_chain]
|
event.room_id, [e.event_id for e in remote_auth_chain]
|
||||||
)
|
)
|
||||||
|
|
||||||
for e in remote_auth_chain:
|
for auth_event in remote_auth_chain:
|
||||||
if e.event_id in seen_remotes:
|
if auth_event.event_id in seen_remotes:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if e.event_id == event.event_id:
|
if auth_event.event_id == event.event_id:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
try:
|
try:
|
||||||
auth_ids = e.auth_event_ids()
|
auth_ids = auth_event.auth_event_ids()
|
||||||
auth = {
|
auth = {
|
||||||
(e.type, e.state_key): e
|
(e.type, e.state_key): e
|
||||||
for e in remote_auth_chain
|
for e in remote_auth_chain
|
||||||
if e.event_id in auth_ids or e.type == EventTypes.Create
|
if e.event_id in auth_ids or e.type == EventTypes.Create
|
||||||
}
|
}
|
||||||
e.internal_metadata.outlier = True
|
auth_event.internal_metadata.outlier = True
|
||||||
|
|
||||||
logger.debug(
|
logger.debug(
|
||||||
"_check_event_auth %s missing_auth: %s",
|
"_check_event_auth %s missing_auth: %s",
|
||||||
event.event_id,
|
event.event_id,
|
||||||
e.event_id,
|
auth_event.event_id,
|
||||||
)
|
)
|
||||||
missing_auth_event_context = (
|
missing_auth_event_context = (
|
||||||
await self._state_handler.compute_event_context(e)
|
await self._state_handler.compute_event_context(auth_event)
|
||||||
)
|
)
|
||||||
await self._auth_and_persist_event(
|
|
||||||
|
missing_auth_event_context = await self._check_event_auth(
|
||||||
origin,
|
origin,
|
||||||
e,
|
auth_event,
|
||||||
missing_auth_event_context,
|
missing_auth_event_context,
|
||||||
claimed_auth_event_map=auth,
|
claimed_auth_event_map=auth,
|
||||||
)
|
)
|
||||||
|
await self.persist_events_and_notify(
|
||||||
|
event.room_id, [(auth_event, missing_auth_event_context)]
|
||||||
|
)
|
||||||
|
|
||||||
if e.event_id in event_auth_events:
|
if auth_event.event_id in event_auth_events:
|
||||||
auth_events[(e.type, e.state_key)] = e
|
auth_events[
|
||||||
|
(auth_event.type, auth_event.state_key)
|
||||||
|
] = auth_event
|
||||||
except AuthError:
|
except AuthError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@ -1733,10 +1700,13 @@ class FederationEventHandler:
|
||||||
context: The event context.
|
context: The event context.
|
||||||
backfilled: True if the event was backfilled.
|
backfilled: True if the event was backfilled.
|
||||||
"""
|
"""
|
||||||
|
# this method should not be called on outliers (those code paths call
|
||||||
|
# persist_events_and_notify directly.)
|
||||||
|
assert not event.internal_metadata.outlier
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if (
|
if (
|
||||||
not event.internal_metadata.is_outlier()
|
not backfilled
|
||||||
and not backfilled
|
|
||||||
and not context.rejected
|
and not context.rejected
|
||||||
and (await self._store.get_min_depth(event.room_id)) <= event.depth
|
and (await self._store.get_min_depth(event.room_id)) <= event.depth
|
||||||
):
|
):
|
||||||
|
|
|
@ -76,9 +76,18 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
|
||||||
|
|
||||||
self.handler = self.homeserver.get_federation_handler()
|
self.handler = self.homeserver.get_federation_handler()
|
||||||
federation_event_handler = self.homeserver.get_federation_event_handler()
|
federation_event_handler = self.homeserver.get_federation_event_handler()
|
||||||
federation_event_handler._check_event_auth = lambda origin, event, context, state, claimed_auth_event_map, backfilled: succeed(
|
|
||||||
context
|
async def _check_event_auth(
|
||||||
)
|
origin,
|
||||||
|
event,
|
||||||
|
context,
|
||||||
|
state=None,
|
||||||
|
claimed_auth_event_map=None,
|
||||||
|
backfilled=False,
|
||||||
|
):
|
||||||
|
return context
|
||||||
|
|
||||||
|
federation_event_handler._check_event_auth = _check_event_auth
|
||||||
self.client = self.homeserver.get_federation_client()
|
self.client = self.homeserver.get_federation_client()
|
||||||
self.client._check_sigs_and_hash_and_fetch = lambda dest, pdus, **k: succeed(
|
self.client._check_sigs_and_hash_and_fetch = lambda dest, pdus, **k: succeed(
|
||||||
pdus
|
pdus
|
||||||
|
|
Loading…
Reference in New Issue