diff --git a/changelog.d/5788.bugfix b/changelog.d/5788.bugfix new file mode 100644 index 0000000000..5632f3cb99 --- /dev/null +++ b/changelog.d/5788.bugfix @@ -0,0 +1 @@ +Correctly handle redactions of redactions. diff --git a/synapse/storage/events_worker.py b/synapse/storage/events_worker.py index 79680ee856..c6fa7f82fd 100644 --- a/synapse/storage/events_worker.py +++ b/synapse/storage/events_worker.py @@ -29,12 +29,7 @@ from synapse.api.room_versions import EventFormatVersions from synapse.events import FrozenEvent, event_type_from_format_version # noqa: F401 from synapse.events.snapshot import EventContext # noqa: F401 from synapse.events.utils import prune_event -from synapse.logging.context import ( - LoggingContext, - PreserveLoggingContext, - make_deferred_yieldable, - run_in_background, -) +from synapse.logging.context import LoggingContext, PreserveLoggingContext from synapse.metrics.background_process_metrics import run_as_background_process from synapse.types import get_domain_from_id from synapse.util import batch_iter @@ -342,13 +337,12 @@ class EventsWorkerStore(SQLBaseStore): log_ctx = LoggingContext.current_context() log_ctx.record_event_fetch(len(missing_events_ids)) - # Note that _enqueue_events is also responsible for turning db rows + # Note that _get_events_from_db is also responsible for turning db rows # into FrozenEvents (via _get_event_from_row), which involves seeing if # the events have been redacted, and if so pulling the redaction event out # of the database to check it. # - # _enqueue_events is a bit of a rubbish name but naming is hard. - missing_events = yield self._enqueue_events( + missing_events = yield self._get_events_from_db( missing_events_ids, allow_rejected=allow_rejected ) @@ -421,28 +415,28 @@ class EventsWorkerStore(SQLBaseStore): The fetch requests. Each entry consists of a list of event ids to be fetched, and a deferred to be completed once the events have been fetched. + + The deferreds are callbacked with a dictionary mapping from event id + to event row. Note that it may well contain additional events that + were not part of this request. """ with Measure(self._clock, "_fetch_event_list"): try: - event_id_lists = list(zip(*event_list))[0] - event_ids = [item for sublist in event_id_lists for item in sublist] + events_to_fetch = set( + event_id for events, _ in event_list for event_id in events + ) row_dict = self._new_transaction( - conn, "do_fetch", [], [], self._fetch_event_rows, event_ids + conn, "do_fetch", [], [], self._fetch_event_rows, events_to_fetch ) # We only want to resolve deferreds from the main thread - def fire(lst, res): - for ids, d in lst: - if not d.called: - try: - with PreserveLoggingContext(): - d.callback([res[i] for i in ids if i in res]) - except Exception: - logger.exception("Failed to callback") + def fire(): + for _, d in event_list: + d.callback(row_dict) with PreserveLoggingContext(): - self.hs.get_reactor().callFromThread(fire, event_list, row_dict) + self.hs.get_reactor().callFromThread(fire) except Exception as e: logger.exception("do_fetch") @@ -457,13 +451,98 @@ class EventsWorkerStore(SQLBaseStore): self.hs.get_reactor().callFromThread(fire, event_list, e) @defer.inlineCallbacks - def _enqueue_events(self, events, allow_rejected=False): + def _get_events_from_db(self, event_ids, allow_rejected=False): + """Fetch a bunch of events from the database. + + Returned events will be added to the cache for future lookups. + + Args: + event_ids (Iterable[str]): The event_ids of the events to fetch + allow_rejected (bool): Whether to include rejected events + + Returns: + Deferred[Dict[str, _EventCacheEntry]]: + map from event id to result. May return extra events which + weren't asked for. + """ + fetched_events = {} + events_to_fetch = event_ids + + while events_to_fetch: + row_map = yield self._enqueue_events(events_to_fetch) + + # we need to recursively fetch any redactions of those events + redaction_ids = set() + for event_id in events_to_fetch: + row = row_map.get(event_id) + fetched_events[event_id] = row + if row: + redaction_ids.update(row["redactions"]) + + events_to_fetch = redaction_ids.difference(fetched_events.keys()) + if events_to_fetch: + logger.debug("Also fetching redaction events %s", events_to_fetch) + + # build a map from event_id to EventBase + event_map = {} + for event_id, row in fetched_events.items(): + if not row: + continue + assert row["event_id"] == event_id + + rejected_reason = row["rejected_reason"] + + if not allow_rejected and rejected_reason: + continue + + d = json.loads(row["json"]) + internal_metadata = json.loads(row["internal_metadata"]) + + format_version = row["format_version"] + if format_version is None: + # This means that we stored the event before we had the concept + # of a event format version, so it must be a V1 event. + format_version = EventFormatVersions.V1 + + original_ev = event_type_from_format_version(format_version)( + event_dict=d, + internal_metadata_dict=internal_metadata, + rejected_reason=rejected_reason, + ) + + event_map[event_id] = original_ev + + # finally, we can decide whether each one nededs redacting, and build + # the cache entries. + result_map = {} + for event_id, original_ev in event_map.items(): + redactions = fetched_events[event_id]["redactions"] + redacted_event = self._maybe_redact_event_row( + original_ev, redactions, event_map + ) + + cache_entry = _EventCacheEntry( + event=original_ev, redacted_event=redacted_event + ) + + self._get_event_cache.prefill((event_id,), cache_entry) + result_map[event_id] = cache_entry + + return result_map + + @defer.inlineCallbacks + def _enqueue_events(self, events): """Fetches events from the database using the _event_fetch_list. This allows batch and bulk fetching of events - it allows us to fetch events without having to create a new transaction for each request for events. + + Args: + events (Iterable[str]): events to be fetched. + + Returns: + Deferred[Dict[str, Dict]]: map from event id to row data from the database. + May contain events that weren't requested. """ - if not events: - return {} events_d = defer.Deferred() with self._event_fetch_lock: @@ -482,32 +561,12 @@ class EventsWorkerStore(SQLBaseStore): "fetch_events", self.runWithConnection, self._do_fetch ) - logger.debug("Loading %d events", len(events)) + logger.debug("Loading %d events: %s", len(events), events) with PreserveLoggingContext(): - rows = yield events_d - logger.debug("Loaded %d events (%d rows)", len(events), len(rows)) + row_map = yield events_d + logger.debug("Loaded %d events (%d rows)", len(events), len(row_map)) - if not allow_rejected: - rows[:] = [r for r in rows if r["rejected_reason"] is None] - - res = yield make_deferred_yieldable( - defer.gatherResults( - [ - run_in_background( - self._get_event_from_row, - row["internal_metadata"], - row["json"], - row["redactions"], - rejected_reason=row["rejected_reason"], - format_version=row["format_version"], - ) - for row in rows - ], - consumeErrors=True, - ) - ) - - return {e.event.event_id: e for e in res if e} + return row_map def _fetch_event_rows(self, txn, event_ids): """Fetch event rows from the database @@ -580,50 +639,7 @@ class EventsWorkerStore(SQLBaseStore): return event_dict - @defer.inlineCallbacks - def _get_event_from_row( - self, internal_metadata, js, redactions, format_version, rejected_reason=None - ): - """Parse an event row which has been read from the database - - Args: - internal_metadata (str): json-encoded internal_metadata column - js (str): json-encoded event body from event_json - redactions (list[str]): a list of the events which claim to have redacted - this event, from the redactions table - format_version: (str): the 'format_version' column - rejected_reason (str|None): the reason this event was rejected, if any - - Returns: - _EventCacheEntry - """ - with Measure(self._clock, "_get_event_from_row"): - d = json.loads(js) - internal_metadata = json.loads(internal_metadata) - - if format_version is None: - # This means that we stored the event before we had the concept - # of a event format version, so it must be a V1 event. - format_version = EventFormatVersions.V1 - - original_ev = event_type_from_format_version(format_version)( - event_dict=d, - internal_metadata_dict=internal_metadata, - rejected_reason=rejected_reason, - ) - - redacted_event = yield self._maybe_redact_event_row(original_ev, redactions) - - cache_entry = _EventCacheEntry( - event=original_ev, redacted_event=redacted_event - ) - - self._get_event_cache.prefill((original_ev.event_id,), cache_entry) - - return cache_entry - - @defer.inlineCallbacks - def _maybe_redact_event_row(self, original_ev, redactions): + def _maybe_redact_event_row(self, original_ev, redactions, event_map): """Given an event object and a list of possible redacting event ids, determine whether to honour any of those redactions and if so return a redacted event. @@ -631,6 +647,8 @@ class EventsWorkerStore(SQLBaseStore): Args: original_ev (EventBase): redactions (iterable[str]): list of event ids of potential redaction events + event_map (dict[str, EventBase]): other events which have been fetched, in + which we can look up the redaaction events. Map from event id to event. Returns: Deferred[EventBase|None]: if the event should be redacted, a pruned @@ -640,15 +658,9 @@ class EventsWorkerStore(SQLBaseStore): # we choose to ignore redactions of m.room.create events. return None - if original_ev.type == "m.room.redaction": - # ... and redaction events - return None - - redaction_map = yield self._get_events_from_cache_or_db(redactions) - for redaction_id in redactions: - redaction_entry = redaction_map.get(redaction_id) - if not redaction_entry: + redaction_event = event_map.get(redaction_id) + if not redaction_event or redaction_event.rejected_reason: # we don't have the redaction event, or the redaction event was not # authorized. logger.debug( @@ -658,7 +670,6 @@ class EventsWorkerStore(SQLBaseStore): ) continue - redaction_event = redaction_entry.event if redaction_event.room_id != original_ev.room_id: logger.debug( "%s was redacted by %s but redaction was in a different room!", diff --git a/tests/storage/test_redaction.py b/tests/storage/test_redaction.py index 8488b6edc8..d961b81d48 100644 --- a/tests/storage/test_redaction.py +++ b/tests/storage/test_redaction.py @@ -17,6 +17,8 @@ from mock import Mock +from twisted.internet import defer + from synapse.api.constants import EventTypes, Membership from synapse.api.room_versions import RoomVersions from synapse.types import RoomID, UserID @@ -216,3 +218,71 @@ class RedactionTestCase(unittest.HomeserverTestCase): }, event.unsigned["redacted_because"], ) + + def test_circular_redaction(self): + redaction_event_id1 = "$redaction1_id:test" + redaction_event_id2 = "$redaction2_id:test" + + class EventIdManglingBuilder: + def __init__(self, base_builder, event_id): + self._base_builder = base_builder + self._event_id = event_id + + @defer.inlineCallbacks + def build(self, prev_event_ids): + built_event = yield self._base_builder.build(prev_event_ids) + built_event.event_id = self._event_id + built_event._event_dict["event_id"] = self._event_id + return built_event + + @property + def room_id(self): + return self._base_builder.room_id + + event_1, context_1 = self.get_success( + self.event_creation_handler.create_new_client_event( + EventIdManglingBuilder( + self.event_builder_factory.for_room_version( + RoomVersions.V1, + { + "type": EventTypes.Redaction, + "sender": self.u_alice.to_string(), + "room_id": self.room1.to_string(), + "content": {"reason": "test"}, + "redacts": redaction_event_id2, + }, + ), + redaction_event_id1, + ) + ) + ) + + self.get_success(self.store.persist_event(event_1, context_1)) + + event_2, context_2 = self.get_success( + self.event_creation_handler.create_new_client_event( + EventIdManglingBuilder( + self.event_builder_factory.for_room_version( + RoomVersions.V1, + { + "type": EventTypes.Redaction, + "sender": self.u_alice.to_string(), + "room_id": self.room1.to_string(), + "content": {"reason": "test"}, + "redacts": redaction_event_id1, + }, + ), + redaction_event_id2, + ) + ) + ) + self.get_success(self.store.persist_event(event_2, context_2)) + + # fetch one of the redactions + fetched = self.get_success(self.store.get_event(redaction_event_id1)) + + # it should have been redacted + self.assertEqual(fetched.unsigned["redacted_by"], redaction_event_id2) + self.assertEqual( + fetched.unsigned["redacted_because"].event_id, redaction_event_id2 + )