From 6c1f7c722f0baade9aecf41f600fcced670c4fcb Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Fri, 15 May 2020 19:03:25 +0100 Subject: [PATCH] Fix limit logic for AccountDataStream (#7384) Make sure that the AccountDataStream presents complete updates, in the right order. This is much the same fix as #7337 and #7358, but applied to a different stream. --- changelog.d/7384.bugfix | 1 + synapse/replication/tcp/streams/_base.py | 68 ++++++++-- .../storage/data_stores/main/account_data.py | 68 ++++++---- .../tcp/streams/test_account_data.py | 117 ++++++++++++++++++ 4 files changed, 220 insertions(+), 34 deletions(-) create mode 100644 changelog.d/7384.bugfix create mode 100644 tests/replication/tcp/streams/test_account_data.py diff --git a/changelog.d/7384.bugfix b/changelog.d/7384.bugfix new file mode 100644 index 0000000000..f49c600173 --- /dev/null +++ b/changelog.d/7384.bugfix @@ -0,0 +1 @@ +Fix a bug where event updates might not be sent over replication to worker processes after the stream falls behind. diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py index b48a6a3e91..d42aaff055 100644 --- a/synapse/replication/tcp/streams/_base.py +++ b/synapse/replication/tcp/streams/_base.py @@ -14,14 +14,27 @@ # See the License for the specific language governing permissions and # limitations under the License. +import heapq import logging from collections import namedtuple -from typing import Any, Awaitable, Callable, List, Optional, Tuple +from typing import ( + TYPE_CHECKING, + Any, + Awaitable, + Callable, + List, + Optional, + Tuple, + TypeVar, +) import attr from synapse.replication.http.streams import ReplicationGetStreamUpdates +if TYPE_CHECKING: + import synapse.server + logger = logging.getLogger(__name__) # the number of rows to request from an update_function. @@ -37,7 +50,7 @@ Token = int # parsing with Stream.parse_row (which turns it into a `ROW_TYPE`). Normally it's # just a row from a database query, though this is dependent on the stream in question. # -StreamRow = Tuple +StreamRow = TypeVar("StreamRow", bound=Tuple) # The type returned by the update_function of a stream, as well as get_updates(), # get_updates_since, etc. @@ -533,32 +546,63 @@ class AccountDataStream(Stream): """ AccountDataStreamRow = namedtuple( - "AccountDataStream", ("user_id", "room_id", "data_type") # str # str # str + "AccountDataStream", + ("user_id", "room_id", "data_type"), # str # Optional[str] # str ) NAME = "account_data" ROW_TYPE = AccountDataStreamRow - def __init__(self, hs): + def __init__(self, hs: "synapse.server.HomeServer"): self.store = hs.get_datastore() super().__init__( hs.get_instance_name(), current_token_without_instance(self.store.get_max_account_data_stream_id), - db_query_to_update_function(self._update_function), + self._update_function, ) - async def _update_function(self, from_token, to_token, limit): - global_results, room_results = await self.store.get_all_updated_account_data( - from_token, from_token, to_token, limit + async def _update_function( + self, instance_name: str, from_token: int, to_token: int, limit: int + ) -> StreamUpdateResult: + limited = False + global_results = await self.store.get_updated_global_account_data( + from_token, to_token, limit ) - results = list(room_results) - results.extend( - (stream_id, user_id, None, account_data_type) + # if the global results hit the limit, we'll need to limit the room results to + # the same stream token. + if len(global_results) >= limit: + to_token = global_results[-1][0] + limited = True + + room_results = await self.store.get_updated_room_account_data( + from_token, to_token, limit + ) + + # likewise, if the room results hit the limit, limit the global results to + # the same stream token. + if len(room_results) >= limit: + to_token = room_results[-1][0] + limited = True + + # convert the global results to the right format, and limit them to the to_token + # at the same time + global_rows = ( + (stream_id, (user_id, None, account_data_type)) for stream_id, user_id, account_data_type in global_results + if stream_id <= to_token ) - return results + # we know that the room_results are already limited to `to_token` so no need + # for a check on `stream_id` here. + room_rows = ( + (stream_id, (user_id, room_id, account_data_type)) + for stream_id, user_id, room_id, account_data_type in room_results + ) + + # we need to return a sorted list, so merge them together. + updates = list(heapq.merge(room_rows, global_rows)) + return updates, to_token, limited class GroupServerStream(Stream): diff --git a/synapse/storage/data_stores/main/account_data.py b/synapse/storage/data_stores/main/account_data.py index 46b494b334..f9eef1b78e 100644 --- a/synapse/storage/data_stores/main/account_data.py +++ b/synapse/storage/data_stores/main/account_data.py @@ -16,6 +16,7 @@ import abc import logging +from typing import List, Tuple from canonicaljson import json @@ -175,41 +176,64 @@ class AccountDataWorkerStore(SQLBaseStore): "get_account_data_for_room_and_type", get_account_data_for_room_and_type_txn ) - def get_all_updated_account_data( - self, last_global_id, last_room_id, current_id, limit - ): - """Get all the client account_data that has changed on the server - Args: - last_global_id(int): The position to fetch from for top level data - last_room_id(int): The position to fetch from for per room data - current_id(int): The position to fetch up to. - Returns: - A deferred pair of lists of tuples of stream_id int, user_id string, - room_id string, and type string. - """ - if last_room_id == current_id and last_global_id == current_id: - return defer.succeed(([], [])) + async def get_updated_global_account_data( + self, last_id: int, current_id: int, limit: int + ) -> List[Tuple[int, str, str]]: + """Get the global account_data that has changed, for the account_data stream - def get_updated_account_data_txn(txn): + Args: + last_id: the last stream_id from the previous batch. + current_id: the maximum stream_id to return up to + limit: the maximum number of rows to return + + Returns: + A list of tuples of stream_id int, user_id string, + and type string. + """ + if last_id == current_id: + return [] + + def get_updated_global_account_data_txn(txn): sql = ( "SELECT stream_id, user_id, account_data_type" " FROM account_data WHERE ? < stream_id AND stream_id <= ?" " ORDER BY stream_id ASC LIMIT ?" ) - txn.execute(sql, (last_global_id, current_id, limit)) - global_results = txn.fetchall() + txn.execute(sql, (last_id, current_id, limit)) + return txn.fetchall() + return await self.db.runInteraction( + "get_updated_global_account_data", get_updated_global_account_data_txn + ) + + async def get_updated_room_account_data( + self, last_id: int, current_id: int, limit: int + ) -> List[Tuple[int, str, str, str]]: + """Get the global account_data that has changed, for the account_data stream + + Args: + last_id: the last stream_id from the previous batch. + current_id: the maximum stream_id to return up to + limit: the maximum number of rows to return + + Returns: + A list of tuples of stream_id int, user_id string, + room_id string and type string. + """ + if last_id == current_id: + return [] + + def get_updated_room_account_data_txn(txn): sql = ( "SELECT stream_id, user_id, room_id, account_data_type" " FROM room_account_data WHERE ? < stream_id AND stream_id <= ?" " ORDER BY stream_id ASC LIMIT ?" ) - txn.execute(sql, (last_room_id, current_id, limit)) - room_results = txn.fetchall() - return global_results, room_results + txn.execute(sql, (last_id, current_id, limit)) + return txn.fetchall() - return self.db.runInteraction( - "get_all_updated_account_data_txn", get_updated_account_data_txn + return await self.db.runInteraction( + "get_updated_room_account_data", get_updated_room_account_data_txn ) def get_updated_account_data_for_user(self, user_id, stream_id): diff --git a/tests/replication/tcp/streams/test_account_data.py b/tests/replication/tcp/streams/test_account_data.py new file mode 100644 index 0000000000..6a5116dd2a --- /dev/null +++ b/tests/replication/tcp/streams/test_account_data.py @@ -0,0 +1,117 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from synapse.replication.tcp.streams._base import ( + _STREAM_UPDATE_TARGET_ROW_COUNT, + AccountDataStream, +) + +from tests.replication._base import BaseStreamTestCase + + +class AccountDataStreamTestCase(BaseStreamTestCase): + def test_update_function_room_account_data_limit(self): + """Test replication with many room account data updates + """ + store = self.hs.get_datastore() + + # generate lots of account data updates + updates = [] + for i in range(_STREAM_UPDATE_TARGET_ROW_COUNT + 5): + update = "m.test_type.%i" % (i,) + self.get_success( + store.add_account_data_to_room("test_user", "test_room", update, {}) + ) + updates.append(update) + + # also one global update + self.get_success(store.add_account_data_for_user("test_user", "m.global", {})) + + # tell the notifier to catch up to avoid duplicate rows. + # workaround for https://github.com/matrix-org/synapse/issues/7360 + # FIXME remove this when the above is fixed + self.replicate() + + # check we're testing what we think we are: no rows should yet have been + # received + self.assertEqual([], self.test_handler.received_rdata_rows) + + # now reconnect to pull the updates + self.reconnect() + self.replicate() + + # we should have received all the expected rows in the right order + received_rows = self.test_handler.received_rdata_rows + + for t in updates: + (stream_name, token, row) = received_rows.pop(0) + self.assertEqual(stream_name, AccountDataStream.NAME) + self.assertIsInstance(row, AccountDataStream.AccountDataStreamRow) + self.assertEqual(row.data_type, t) + self.assertEqual(row.room_id, "test_room") + + (stream_name, token, row) = received_rows.pop(0) + self.assertIsInstance(row, AccountDataStream.AccountDataStreamRow) + self.assertEqual(row.data_type, "m.global") + self.assertIsNone(row.room_id) + + self.assertEqual([], received_rows) + + def test_update_function_global_account_data_limit(self): + """Test replication with many global account data updates + """ + store = self.hs.get_datastore() + + # generate lots of account data updates + updates = [] + for i in range(_STREAM_UPDATE_TARGET_ROW_COUNT + 5): + update = "m.test_type.%i" % (i,) + self.get_success(store.add_account_data_for_user("test_user", update, {})) + updates.append(update) + + # also one per-room update + self.get_success( + store.add_account_data_to_room("test_user", "test_room", "m.per_room", {}) + ) + + # tell the notifier to catch up to avoid duplicate rows. + # workaround for https://github.com/matrix-org/synapse/issues/7360 + # FIXME remove this when the above is fixed + self.replicate() + + # check we're testing what we think we are: no rows should yet have been + # received + self.assertEqual([], self.test_handler.received_rdata_rows) + + # now reconnect to pull the updates + self.reconnect() + self.replicate() + + # we should have received all the expected rows in the right order + received_rows = self.test_handler.received_rdata_rows + + for t in updates: + (stream_name, token, row) = received_rows.pop(0) + self.assertEqual(stream_name, AccountDataStream.NAME) + self.assertIsInstance(row, AccountDataStream.AccountDataStreamRow) + self.assertEqual(row.data_type, t) + self.assertIsNone(row.room_id) + + (stream_name, token, row) = received_rows.pop(0) + self.assertIsInstance(row, AccountDataStream.AccountDataStreamRow) + self.assertEqual(row.data_type, "m.per_room") + self.assertEqual(row.room_id, "test_room") + + self.assertEqual([], received_rows)