From 7469824d5838577f5a07aec6ab73b457459d8b4a Mon Sep 17 00:00:00 2001
From: Erik Johnston <erik@matrix.org>
Date: Tue, 28 Jun 2022 13:13:44 +0100
Subject: [PATCH] Fix serialization errors when rotating notifications (#13118)

---
 changelog.d/13118.misc                        |   1 +
 .../databases/main/event_push_actions.py      | 201 ++++++++++++------
 synapse/storage/databases/main/receipts.py    |  13 +-
 .../delta/72/01event_push_summary_receipt.sql |  35 +++
 tests/storage/test_event_push_actions.py      |  35 ++-
 5 files changed, 202 insertions(+), 83 deletions(-)
 create mode 100644 changelog.d/13118.misc
 create mode 100644 synapse/storage/schema/main/delta/72/01event_push_summary_receipt.sql

diff --git a/changelog.d/13118.misc b/changelog.d/13118.misc
new file mode 100644
index 0000000000..3bb51962e7
--- /dev/null
+++ b/changelog.d/13118.misc
@@ -0,0 +1 @@
+Reduce DB usage of `/sync` when a large number of unread messages have recently been sent in a room.
diff --git a/synapse/storage/databases/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py
index 10a7962382..80ca2fd0b6 100644
--- a/synapse/storage/databases/main/event_push_actions.py
+++ b/synapse/storage/databases/main/event_push_actions.py
@@ -233,14 +233,30 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, EventsWorkerStore, SQLBas
 
         counts = NotifCounts()
 
-        # First we pull the counts from the summary table
+        # First we pull the counts from the summary table.
+        #
+        # We check that `last_receipt_stream_ordering` matches the stream
+        # ordering given. If it doesn't match then a new read receipt has arrived and
+        # we haven't yet updated the counts in `event_push_summary` to reflect
+        # that; in that case we simply ignore `event_push_summary` counts
+        # and do a manual count of all of the rows in the `event_push_actions` table
+        # for this user/room.
+        #
+        # If `last_receipt_stream_ordering` is null then that means it's up to
+        # date (as the row was written by an older version of Synapse that
+        # updated `event_push_summary` synchronously when persisting a new read
+        # receipt).
         txn.execute(
             """
                 SELECT stream_ordering, notif_count, COALESCE(unread_count, 0)
                 FROM event_push_summary
-                WHERE room_id = ? AND user_id = ? AND stream_ordering > ?
+                WHERE room_id = ? AND user_id = ?
+                AND (
+                    (last_receipt_stream_ordering IS NULL AND stream_ordering > ?)
+                    OR last_receipt_stream_ordering = ?
+                )
             """,
-            (room_id, user_id, stream_ordering),
+            (room_id, user_id, stream_ordering, stream_ordering),
         )
         row = txn.fetchone()
 
@@ -263,9 +279,9 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, EventsWorkerStore, SQLBas
         if row:
             counts.highlight_count += row[0]
 
-        # Finally we need to count push actions that haven't been summarized
-        # yet.
-        # We only want to pull out push actions that we haven't summarized yet.
+        # Finally we need to count push actions that aren't included in the
+        # summary returned above, e.g. recent events that haven't been
+        # summarized yet, or the summary is empty due to a recent read receipt.
         stream_ordering = max(stream_ordering, summary_stream_ordering)
         notify_count, unread_count = self._get_notif_unread_count_for_user_room(
             txn, room_id, user_id, stream_ordering
@@ -800,6 +816,19 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, EventsWorkerStore, SQLBas
         self._doing_notif_rotation = True
 
         try:
+            # First we recalculate push summaries and delete stale push actions
+            # for rooms/users with new receipts.
+            while True:
+                logger.debug("Handling new receipts")
+
+                caught_up = await self.db_pool.runInteraction(
+                    "_handle_new_receipts_for_notifs_txn",
+                    self._handle_new_receipts_for_notifs_txn,
+                )
+                if caught_up:
+                    break
+
+            # Then we update the event push summaries for any new events
             while True:
                 logger.info("Rotating notifications")
 
@@ -810,10 +839,110 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, EventsWorkerStore, SQLBas
                     break
                 await self.hs.get_clock().sleep(self._rotate_delay)
 
+            # Finally we clear out old event push actions.
             await self._remove_old_push_actions_that_have_rotated()
         finally:
             self._doing_notif_rotation = False
 
+    def _handle_new_receipts_for_notifs_txn(self, txn: LoggingTransaction) -> bool:
+        """Check for new read receipts and delete from event push actions.
+
+        Any push actions which predate the user's most recent read receipt are
+        now redundant, so we can remove them from `event_push_actions` and
+        update `event_push_summary`.
+        """
+
+        limit = 100
+
+        min_stream_id = self.db_pool.simple_select_one_onecol_txn(
+            txn,
+            table="event_push_summary_last_receipt_stream_id",
+            keyvalues={},
+            retcol="stream_id",
+        )
+
+        sql = """
+            SELECT r.stream_id, r.room_id, r.user_id, e.stream_ordering
+            FROM receipts_linearized AS r
+            INNER JOIN events AS e USING (event_id)
+            WHERE r.stream_id > ? AND user_id LIKE ?
+            ORDER BY r.stream_id ASC
+            LIMIT ?
+        """
+
+        # We only want local users, so we add a dodgy filter to the above query
+        # and recheck it below.
+        user_filter = "%:" + self.hs.hostname
+
+        txn.execute(
+            sql,
+            (
+                min_stream_id,
+                user_filter,
+                limit,
+            ),
+        )
+        rows = txn.fetchall()
+
+        # For each new read receipt we delete push actions from before it and
+        # recalculate the summary.
+        for _, room_id, user_id, stream_ordering in rows:
+            # Only handle our own read receipts.
+            if not self.hs.is_mine_id(user_id):
+                continue
+
+            txn.execute(
+                """
+                DELETE FROM event_push_actions
+                WHERE room_id = ?
+                    AND user_id = ?
+                    AND stream_ordering <= ?
+                    AND highlight = 0
+                """,
+                (room_id, user_id, stream_ordering),
+            )
+
+            old_rotate_stream_ordering = self.db_pool.simple_select_one_onecol_txn(
+                txn,
+                table="event_push_summary_stream_ordering",
+                keyvalues={},
+                retcol="stream_ordering",
+            )
+
+            notif_count, unread_count = self._get_notif_unread_count_for_user_room(
+                txn, room_id, user_id, stream_ordering, old_rotate_stream_ordering
+            )
+
+            self.db_pool.simple_upsert_txn(
+                txn,
+                table="event_push_summary",
+                keyvalues={"room_id": room_id, "user_id": user_id},
+                values={
+                    "notif_count": notif_count,
+                    "unread_count": unread_count,
+                    "stream_ordering": old_rotate_stream_ordering,
+                    "last_receipt_stream_ordering": stream_ordering,
+                },
+            )
+
+        # We always update `event_push_summary_last_receipt_stream_id` to
+        # ensure that we don't rescan the same receipts for remote users.
+        #
+        # This requires repeatable read to be safe, as we need the
+        # `MAX(stream_id)` to not include any new rows that have been committed
+        # since the start of the transaction (since those rows won't have been
+        # returned by the query above). Alternatively we could query the max
+        # stream ID at the start of the transaction and bound everything by
+        # that.
+        txn.execute(
+            """
+            UPDATE event_push_summary_last_receipt_stream_id
+            SET stream_id = (SELECT COALESCE(MAX(stream_id), 0) FROM receipts_linearized)
+            """
+        )
+
+        return len(rows) < limit
+
     def _rotate_notifs_txn(self, txn: LoggingTransaction) -> bool:
         """Archives older notifications into event_push_summary. Returns whether
         the archiving process has caught up or not.
@@ -1033,66 +1162,6 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, EventsWorkerStore, SQLBas
             if done:
                 break
 
-    def _remove_old_push_actions_before_txn(
-        self, txn: LoggingTransaction, room_id: str, user_id: str, stream_ordering: int
-    ) -> None:
-        """
-        Purges old push actions for a user and room before a given
-        stream_ordering.
-
-        We however keep a months worth of highlighted notifications, so that
-        users can still get a list of recent highlights.
-
-        Args:
-            txn: The transaction
-            room_id: Room ID to delete from
-            user_id: user ID to delete for
-            stream_ordering: The lowest stream ordering which will
-                                  not be deleted.
-        """
-        txn.call_after(
-            self.get_unread_event_push_actions_by_room_for_user.invalidate,
-            (room_id, user_id),
-        )
-
-        # We need to join on the events table to get the received_ts for
-        # event_push_actions and sqlite won't let us use a join in a delete so
-        # we can't just delete where received_ts < x. Furthermore we can
-        # only identify event_push_actions by a tuple of room_id, event_id
-        # we we can't use a subquery.
-        # Instead, we look up the stream ordering for the last event in that
-        # room received before the threshold time and delete event_push_actions
-        # in the room with a stream_odering before that.
-        txn.execute(
-            "DELETE FROM event_push_actions "
-            " WHERE user_id = ? AND room_id = ? AND "
-            " stream_ordering <= ?"
-            " AND ((stream_ordering < ? AND highlight = 1) or highlight = 0)",
-            (user_id, room_id, stream_ordering, self.stream_ordering_month_ago),
-        )
-
-        old_rotate_stream_ordering = self.db_pool.simple_select_one_onecol_txn(
-            txn,
-            table="event_push_summary_stream_ordering",
-            keyvalues={},
-            retcol="stream_ordering",
-        )
-
-        notif_count, unread_count = self._get_notif_unread_count_for_user_room(
-            txn, room_id, user_id, stream_ordering, old_rotate_stream_ordering
-        )
-
-        self.db_pool.simple_upsert_txn(
-            txn,
-            table="event_push_summary",
-            keyvalues={"room_id": room_id, "user_id": user_id},
-            values={
-                "notif_count": notif_count,
-                "unread_count": unread_count,
-                "stream_ordering": old_rotate_stream_ordering,
-            },
-        )
-
 
 class EventPushActionsStore(EventPushActionsWorkerStore):
     EPA_HIGHLIGHT_INDEX = "epa_highlight_index"
diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py
index bec6d60577..0090c9f225 100644
--- a/synapse/storage/databases/main/receipts.py
+++ b/synapse/storage/databases/main/receipts.py
@@ -26,7 +26,7 @@ from typing import (
     cast,
 )
 
-from synapse.api.constants import EduTypes, ReceiptTypes
+from synapse.api.constants import EduTypes
 from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
 from synapse.replication.tcp.streams import ReceiptsStream
 from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
@@ -682,17 +682,6 @@ class ReceiptsWorkerStore(SQLBaseStore):
             lock=False,
         )
 
-        # When updating a local users read receipt, remove any push actions
-        # which resulted from the receipt's event and all earlier events.
-        if (
-            self.hs.is_mine_id(user_id)
-            and receipt_type in (ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE)
-            and stream_ordering is not None
-        ):
-            self._remove_old_push_actions_before_txn(  # type: ignore[attr-defined]
-                txn, room_id=room_id, user_id=user_id, stream_ordering=stream_ordering
-            )
-
         return rx_ts
 
     def _graph_to_linear(
diff --git a/synapse/storage/schema/main/delta/72/01event_push_summary_receipt.sql b/synapse/storage/schema/main/delta/72/01event_push_summary_receipt.sql
new file mode 100644
index 0000000000..e45db61529
--- /dev/null
+++ b/synapse/storage/schema/main/delta/72/01event_push_summary_receipt.sql
@@ -0,0 +1,35 @@
+/* Copyright 2022 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- Add a column that records the position of the read receipt for the user at
+-- the time we summarised the push actions. This is used to check if the counts
+-- are up to date after a new read receipt has been sent.
+--
+-- Null means that we can skip that check, as the row was written by an older
+-- version of Synapse that updated `event_push_summary` synchronously when
+-- persisting a new read receipt
+ALTER TABLE event_push_summary ADD COLUMN last_receipt_stream_ordering BIGINT;
+
+
+-- Tracks which new receipts we've handled
+CREATE TABLE event_push_summary_last_receipt_stream_id (
+    Lock CHAR(1) NOT NULL DEFAULT 'X' UNIQUE,  -- Makes sure this table only has one row.
+    stream_id BIGINT NOT NULL,
+    CHECK (Lock='X')
+);
+
+INSERT INTO event_push_summary_last_receipt_stream_id (stream_id)
+  SELECT COALESCE(MAX(stream_id), 0)
+  FROM receipts_linearized;
diff --git a/tests/storage/test_event_push_actions.py b/tests/storage/test_event_push_actions.py
index 2ac5f6db5e..ef069a8110 100644
--- a/tests/storage/test_event_push_actions.py
+++ b/tests/storage/test_event_push_actions.py
@@ -55,7 +55,7 @@ class EventPushActionsStoreTestCase(HomeserverTestCase):
 
     def test_count_aggregation(self) -> None:
         room_id = "!foo:example.com"
-        user_id = "@user1235:example.com"
+        user_id = "@user1235:test"
 
         last_read_stream_ordering = [0]
 
@@ -81,11 +81,26 @@ class EventPushActionsStoreTestCase(HomeserverTestCase):
         def _inject_actions(stream: int, action: list) -> None:
             event = Mock()
             event.room_id = room_id
-            event.event_id = "$test:example.com"
+            event.event_id = f"$test{stream}:example.com"
             event.internal_metadata.stream_ordering = stream
             event.internal_metadata.is_outlier.return_value = False
             event.depth = stream
 
+            self.get_success(
+                self.store.db_pool.simple_insert(
+                    table="events",
+                    values={
+                        "stream_ordering": stream,
+                        "topological_ordering": stream,
+                        "type": "m.room.message",
+                        "room_id": room_id,
+                        "processed": True,
+                        "outlier": False,
+                        "event_id": event.event_id,
+                    },
+                )
+            )
+
             self.get_success(
                 self.store.add_push_actions_to_staging(
                     event.event_id,
@@ -105,18 +120,28 @@ class EventPushActionsStoreTestCase(HomeserverTestCase):
         def _rotate(stream: int) -> None:
             self.get_success(
                 self.store.db_pool.runInteraction(
-                    "", self.store._rotate_notifs_before_txn, stream
+                    "rotate-receipts", self.store._handle_new_receipts_for_notifs_txn
+                )
+            )
+
+            self.get_success(
+                self.store.db_pool.runInteraction(
+                    "rotate-notifs", self.store._rotate_notifs_before_txn, stream
                 )
             )
 
         def _mark_read(stream: int, depth: int) -> None:
             last_read_stream_ordering[0] = stream
+
             self.get_success(
                 self.store.db_pool.runInteraction(
                     "",
-                    self.store._remove_old_push_actions_before_txn,
+                    self.store._insert_linearized_receipt_txn,
                     room_id,
+                    "m.read",
                     user_id,
+                    f"$test{stream}:example.com",
+                    {},
                     stream,
                 )
             )
@@ -150,7 +175,7 @@ class EventPushActionsStoreTestCase(HomeserverTestCase):
 
         _assert_counts(1, 0)
 
-        _mark_read(7, 7)
+        _mark_read(6, 6)
         _assert_counts(0, 0)
 
         _inject_actions(8, HIGHLIGHT)