Fix a bug where redactions were not being sent over federation if we did not have the original event. (#13813)

pull/14145/head
Shay 2022-10-11 11:18:45 -07:00 committed by GitHub
parent 6a92944854
commit a86b2f6837
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 62 additions and 38 deletions

1
changelog.d/13813.bugfix Normal file
View File

@ -0,0 +1 @@
Fix a long-standing bug where redactions were not being sent over federation if we did not have the original event.

View File

@ -353,21 +353,25 @@ class FederationSender(AbstractFederationSender):
last_token = await self.store.get_federation_out_pos("events")
(
next_token,
events,
event_to_received_ts,
) = await self.store.get_all_new_events_stream(
) = await self.store.get_all_new_event_ids_stream(
last_token, self._last_poked_id, limit=100
)
event_ids = event_to_received_ts.keys()
event_entries = await self.store.get_unredacted_events_from_cache_or_db(
event_ids
)
logger.debug(
"Handling %i -> %i: %i events to send (current id %i)",
last_token,
next_token,
len(events),
len(event_entries),
self._last_poked_id,
)
if not events and next_token >= self._last_poked_id:
if not event_entries and next_token >= self._last_poked_id:
logger.debug("All events processed")
break
@ -508,8 +512,14 @@ class FederationSender(AbstractFederationSender):
await handle_event(event)
events_by_room: Dict[str, List[EventBase]] = {}
for event in events:
events_by_room.setdefault(event.room_id, []).append(event)
for event_id in event_ids:
# `event_entries` is unsorted, so we have to iterate over `event_ids`
# to ensure the events are in the right order
event_cache = event_entries.get(event_id)
if event_cache:
event = event_cache.event
events_by_room.setdefault(event.room_id, []).append(event)
await make_deferred_yieldable(
defer.gatherResults(
@ -524,9 +534,10 @@ class FederationSender(AbstractFederationSender):
logger.debug("Successfully handled up to %i", next_token)
await self.store.update_federation_out_pos("events", next_token)
if events:
if event_entries:
now = self.clock.time_msec()
ts = event_to_received_ts[events[-1].event_id]
last_id = next(reversed(event_ids))
ts = event_to_received_ts[last_id]
assert ts is not None
synapse.metrics.event_processing_lag.labels(
@ -536,7 +547,7 @@ class FederationSender(AbstractFederationSender):
"federation_sender"
).set(ts)
events_processed_counter.inc(len(events))
events_processed_counter.inc(len(event_entries))
event_processing_loop_room_count.labels("federation_sender").inc(
len(events_by_room)

View File

@ -109,10 +109,13 @@ class ApplicationServicesHandler:
last_token = await self.store.get_appservice_last_pos()
(
upper_bound,
events,
event_to_received_ts,
) = await self.store.get_all_new_events_stream(
last_token, self.current_max, limit=100, get_prev_content=True
) = await self.store.get_all_new_event_ids_stream(
last_token, self.current_max, limit=100
)
events = await self.store.get_events_as_list(
event_to_received_ts.keys(), get_prev_content=True
)
events_by_room: Dict[str, List[EventBase]] = {}

View File

@ -474,7 +474,7 @@ class EventsWorkerStore(SQLBaseStore):
return []
# there may be duplicates so we cast the list to a set
event_entry_map = await self._get_events_from_cache_or_db(
event_entry_map = await self.get_unredacted_events_from_cache_or_db(
set(event_ids), allow_rejected=allow_rejected
)
@ -509,7 +509,9 @@ class EventsWorkerStore(SQLBaseStore):
continue
redacted_event_id = entry.event.redacts
event_map = await self._get_events_from_cache_or_db([redacted_event_id])
event_map = await self.get_unredacted_events_from_cache_or_db(
[redacted_event_id]
)
original_event_entry = event_map.get(redacted_event_id)
if not original_event_entry:
# we don't have the redacted event (or it was rejected).
@ -588,11 +590,16 @@ class EventsWorkerStore(SQLBaseStore):
return events
@cancellable
async def _get_events_from_cache_or_db(
self, event_ids: Iterable[str], allow_rejected: bool = False
async def get_unredacted_events_from_cache_or_db(
self,
event_ids: Iterable[str],
allow_rejected: bool = False,
) -> Dict[str, EventCacheEntry]:
"""Fetch a bunch of events from the cache or the database.
Note that the events pulled by this function will not have any redactions
applied, and no guarantee is made about the ordering of the events returned.
If events are pulled from the database, they will be cached for future lookups.
Unknown events are omitted from the response.

View File

@ -1024,28 +1024,31 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
"after": {"event_ids": events_after, "token": end_token},
}
async def get_all_new_events_stream(
self, from_id: int, current_id: int, limit: int, get_prev_content: bool = False
) -> Tuple[int, List[EventBase], Dict[str, Optional[int]]]:
async def get_all_new_event_ids_stream(
self,
from_id: int,
current_id: int,
limit: int,
) -> Tuple[int, Dict[str, Optional[int]]]:
"""Get all new events
Returns all events with from_id < stream_ordering <= current_id.
Returns all event ids with from_id < stream_ordering <= current_id.
Args:
from_id: the stream_ordering of the last event we processed
current_id: the stream_ordering of the most recently processed event
limit: the maximum number of events to return
get_prev_content: whether to fetch previous event content
Returns:
A tuple of (next_id, events, event_to_received_ts), where `next_id`
A tuple of (next_id, event_to_received_ts), where `next_id`
is the next value to pass as `from_id` (it will either be the
stream_ordering of the last returned event, or, if fewer than `limit`
events were found, the `current_id`). The `event_to_received_ts` is
a dictionary mapping event ID to the event `received_ts`.
a dictionary mapping event ID to the event `received_ts`, sorted by ascending
stream_ordering.
"""
def get_all_new_events_stream_txn(
def get_all_new_event_ids_stream_txn(
txn: LoggingTransaction,
) -> Tuple[int, Dict[str, Optional[int]]]:
sql = (
@ -1070,15 +1073,10 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
return upper_bound, event_to_received_ts
upper_bound, event_to_received_ts = await self.db_pool.runInteraction(
"get_all_new_events_stream", get_all_new_events_stream_txn
"get_all_new_event_ids_stream", get_all_new_event_ids_stream_txn
)
events = await self.get_events_as_list(
event_to_received_ts.keys(),
get_prev_content=get_prev_content,
)
return upper_bound, events, event_to_received_ts
return upper_bound, event_to_received_ts
async def get_federation_out_pos(self, typ: str) -> int:
if self._need_to_reset_federation_stream_positions:

View File

@ -76,9 +76,13 @@ class AppServiceHandlerTestCase(unittest.TestCase):
event = Mock(
sender="@someone:anywhere", type="m.room.message", room_id="!foo:bar"
)
self.mock_store.get_all_new_events_stream.side_effect = [
make_awaitable((0, [], {})),
make_awaitable((1, [event], {event.event_id: 0})),
self.mock_store.get_all_new_event_ids_stream.side_effect = [
make_awaitable((0, {})),
make_awaitable((1, {event.event_id: 0})),
]
self.mock_store.get_events_as_list.side_effect = [
make_awaitable([]),
make_awaitable([event]),
]
self.handler.notify_interested_services(RoomStreamToken(None, 1))
@ -95,10 +99,10 @@ class AppServiceHandlerTestCase(unittest.TestCase):
event = Mock(sender=user_id, type="m.room.message", room_id="!foo:bar")
self.mock_as_api.query_user.return_value = make_awaitable(True)
self.mock_store.get_all_new_events_stream.side_effect = [
make_awaitable((0, [event], {event.event_id: 0})),
self.mock_store.get_all_new_event_ids_stream.side_effect = [
make_awaitable((0, {event.event_id: 0})),
]
self.mock_store.get_events_as_list.side_effect = [make_awaitable([event])]
self.handler.notify_interested_services(RoomStreamToken(None, 0))
self.mock_as_api.query_user.assert_called_once_with(services[0], user_id)
@ -112,7 +116,7 @@ class AppServiceHandlerTestCase(unittest.TestCase):
event = Mock(sender=user_id, type="m.room.message", room_id="!foo:bar")
self.mock_as_api.query_user.return_value = make_awaitable(True)
self.mock_store.get_all_new_events_stream.side_effect = [
self.mock_store.get_all_new_event_ids_stream.side_effect = [
make_awaitable((0, [event], {event.event_id: 0})),
]