From f5ac4dc2d46d329e7053259c61ad402269903ee3 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 16 Feb 2018 12:17:14 +0000 Subject: [PATCH 1/4] Split ReceiptsStore --- synapse/replication/slave/storage/receipts.py | 33 +----- synapse/storage/__init__.py | 3 - synapse/storage/receipts.py | 109 ++++++++++-------- 3 files changed, 69 insertions(+), 76 deletions(-) diff --git a/synapse/replication/slave/storage/receipts.py b/synapse/replication/slave/storage/receipts.py index b371574ece..4e845ec041 100644 --- a/synapse/replication/slave/storage/receipts.py +++ b/synapse/replication/slave/storage/receipts.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2016 OpenMarket Ltd +# Copyright 2018 New Vector Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -16,9 +17,7 @@ from ._base import BaseSlavedStore from ._slaved_id_tracker import SlavedIdTracker -from synapse.storage import DataStore -from synapse.storage.receipts import ReceiptsStore -from synapse.util.caches.stream_change_cache import StreamChangeCache +from synapse.storage.receipts import ReceiptsWorkerStore # So, um, we want to borrow a load of functions intended for reading from # a DataStore, but we don't want to take functions that either write to the @@ -29,36 +28,14 @@ from synapse.util.caches.stream_change_cache import StreamChangeCache # the method descriptor on the DataStore and chuck them into our class. -class SlavedReceiptsStore(BaseSlavedStore): +class SlavedReceiptsStore(ReceiptsWorkerStore, BaseSlavedStore): def __init__(self, db_conn, hs): - super(SlavedReceiptsStore, self).__init__(db_conn, hs) - - self._receipts_id_gen = SlavedIdTracker( + receipts_id_gen = SlavedIdTracker( db_conn, "receipts_linearized", "stream_id" ) - self._receipts_stream_cache = StreamChangeCache( - "ReceiptsRoomChangeCache", self._receipts_id_gen.get_current_token() - ) - - get_receipts_for_user = ReceiptsStore.__dict__["get_receipts_for_user"] - get_linearized_receipts_for_room = ( - ReceiptsStore.__dict__["get_linearized_receipts_for_room"] - ) - _get_linearized_receipts_for_rooms = ( - ReceiptsStore.__dict__["_get_linearized_receipts_for_rooms"] - ) - get_last_receipt_event_id_for_user = ( - ReceiptsStore.__dict__["get_last_receipt_event_id_for_user"] - ) - - get_max_receipt_stream_id = DataStore.get_max_receipt_stream_id.__func__ - get_all_updated_receipts = DataStore.get_all_updated_receipts.__func__ - - get_linearized_receipts_for_rooms = ( - DataStore.get_linearized_receipts_for_rooms.__func__ - ) + super(SlavedReceiptsStore, self).__init__(receipts_id_gen, db_conn, hs) def stream_positions(self): result = super(SlavedReceiptsStore, self).stream_positions() diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py index f8fbd02ceb..e1c4fe086e 100644 --- a/synapse/storage/__init__.py +++ b/synapse/storage/__init__.py @@ -104,9 +104,6 @@ class DataStore(RoomMemberStore, RoomStore, db_conn, "events", "stream_ordering", step=-1, extra_tables=[("ex_outlier_stream", "event_stream_ordering")] ) - self._receipts_id_gen = StreamIdGenerator( - db_conn, "receipts_linearized", "stream_id" - ) self._account_data_id_gen = StreamIdGenerator( db_conn, "account_data_max_stream_id", "stream_id" ) diff --git a/synapse/storage/receipts.py b/synapse/storage/receipts.py index 12b3cc7f5f..aa62474a46 100644 --- a/synapse/storage/receipts.py +++ b/synapse/storage/receipts.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2014-2016 OpenMarket Ltd +# Copyright 2018 New Vector Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -14,6 +15,7 @@ # limitations under the License. from ._base import SQLBaseStore +from .util.id_generators import StreamIdGenerator from synapse.util.caches.descriptors import cachedInlineCallbacks, cachedList, cached from synapse.util.caches.stream_change_cache import StreamChangeCache @@ -26,9 +28,17 @@ import ujson as json logger = logging.getLogger(__name__) -class ReceiptsStore(SQLBaseStore): - def __init__(self, db_conn, hs): - super(ReceiptsStore, self).__init__(db_conn, hs) +class ReceiptsWorkerStore(SQLBaseStore): + def __init__(self, receipts_id_gen, db_conn, hs): + """ + Args: + receipts_id_gen (StreamIdGenerator|SlavedIdTracker) + db_conn: Database connection + hs (Homeserver) + """ + super(ReceiptsWorkerStore, self).__init__(db_conn, hs) + + self._receipts_id_gen = receipts_id_gen self._receipts_stream_cache = StreamChangeCache( "ReceiptsRoomChangeCache", self._receipts_id_gen.get_current_token() @@ -39,26 +49,6 @@ class ReceiptsStore(SQLBaseStore): receipts = yield self.get_receipts_for_room(room_id, "m.read") defer.returnValue(set(r['user_id'] for r in receipts)) - def _invalidate_get_users_with_receipts_in_room(self, room_id, receipt_type, - user_id): - if receipt_type != "m.read": - return - - # Returns an ObservableDeferred - res = self.get_users_with_read_receipts_in_room.cache.get( - room_id, None, update_metrics=False, - ) - - if res: - if isinstance(res, defer.Deferred) and res.called: - res = res.result - if user_id in res: - # We'd only be adding to the set, so no point invalidating if the - # user is already there - return - - self.get_users_with_read_receipts_in_room.invalidate((room_id,)) - @cached(num_args=2) def get_receipts_for_room(self, room_id, receipt_type): return self._simple_select_list( @@ -273,6 +263,57 @@ class ReceiptsStore(SQLBaseStore): def get_max_receipt_stream_id(self): return self._receipts_id_gen.get_current_token() + def get_all_updated_receipts(self, last_id, current_id, limit=None): + if last_id == current_id: + return defer.succeed([]) + + def get_all_updated_receipts_txn(txn): + sql = ( + "SELECT stream_id, room_id, receipt_type, user_id, event_id, data" + " FROM receipts_linearized" + " WHERE ? < stream_id AND stream_id <= ?" + " ORDER BY stream_id ASC" + ) + args = [last_id, current_id] + if limit is not None: + sql += " LIMIT ?" + args.append(limit) + txn.execute(sql, args) + + return txn.fetchall() + return self.runInteraction( + "get_all_updated_receipts", get_all_updated_receipts_txn + ) + + +class ReceiptsStore(ReceiptsWorkerStore): + def __init__(self, db_conn, hs): + receipts_id_gen = StreamIdGenerator( + db_conn, "receipts_linearized", "stream_id" + ) + + super(ReceiptsStore, self).__init__(receipts_id_gen, db_conn, hs) + + def _invalidate_get_users_with_receipts_in_room(self, room_id, receipt_type, + user_id): + if receipt_type != "m.read": + return + + # Returns an ObservableDeferred + res = self.get_users_with_read_receipts_in_room.cache.get( + room_id, None, update_metrics=False, + ) + + if res: + if isinstance(res, defer.Deferred) and res.called: + res = res.result + if user_id in res: + # We'd only be adding to the set, so no point invalidating if the + # user is already there + return + + self.get_users_with_read_receipts_in_room.invalidate((room_id,)) + def insert_linearized_receipt_txn(self, txn, room_id, receipt_type, user_id, event_id, data, stream_id): txn.call_after( @@ -457,25 +498,3 @@ class ReceiptsStore(SQLBaseStore): "data": json.dumps(data), } ) - - def get_all_updated_receipts(self, last_id, current_id, limit=None): - if last_id == current_id: - return defer.succeed([]) - - def get_all_updated_receipts_txn(txn): - sql = ( - "SELECT stream_id, room_id, receipt_type, user_id, event_id, data" - " FROM receipts_linearized" - " WHERE ? < stream_id AND stream_id <= ?" - " ORDER BY stream_id ASC" - ) - args = [last_id, current_id] - if limit is not None: - sql += " LIMIT ?" - args.append(limit) - txn.execute(sql, args) - - return txn.fetchall() - return self.runInteraction( - "get_all_updated_receipts", get_all_updated_receipts_txn - ) From e316bbb4c07cad97c4cff5bc0c5b0dc2cd7bc519 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 20 Feb 2018 17:33:18 +0000 Subject: [PATCH 2/4] Use abstract base class to access stream IDs --- synapse/replication/slave/storage/receipts.py | 9 +++- synapse/storage/receipts.py | 42 ++++++++++++------- 2 files changed, 34 insertions(+), 17 deletions(-) diff --git a/synapse/replication/slave/storage/receipts.py b/synapse/replication/slave/storage/receipts.py index 4e845ec041..a2eb4a02db 100644 --- a/synapse/replication/slave/storage/receipts.py +++ b/synapse/replication/slave/storage/receipts.py @@ -31,11 +31,16 @@ from synapse.storage.receipts import ReceiptsWorkerStore class SlavedReceiptsStore(ReceiptsWorkerStore, BaseSlavedStore): def __init__(self, db_conn, hs): - receipts_id_gen = SlavedIdTracker( + # We instansiate this first as the ReceiptsWorkerStore constructor + # needs to be able to call get_max_receipt_stream_id + self._receipts_id_gen = SlavedIdTracker( db_conn, "receipts_linearized", "stream_id" ) - super(SlavedReceiptsStore, self).__init__(receipts_id_gen, db_conn, hs) + super(SlavedReceiptsStore, self).__init__(db_conn, hs) + + def get_max_receipt_stream_id(self): + return self._receipts_id_gen.get_current_token() def stream_positions(self): result = super(SlavedReceiptsStore, self).stream_positions() diff --git a/synapse/storage/receipts.py b/synapse/storage/receipts.py index aa62474a46..b11cf7ff62 100644 --- a/synapse/storage/receipts.py +++ b/synapse/storage/receipts.py @@ -21,6 +21,7 @@ from synapse.util.caches.stream_change_cache import StreamChangeCache from twisted.internet import defer +import abc import logging import ujson as json @@ -29,21 +30,30 @@ logger = logging.getLogger(__name__) class ReceiptsWorkerStore(SQLBaseStore): - def __init__(self, receipts_id_gen, db_conn, hs): - """ - Args: - receipts_id_gen (StreamIdGenerator|SlavedIdTracker) - db_conn: Database connection - hs (Homeserver) - """ + """This is an abstract base class where subclasses must implement + `get_max_receipt_stream_id` which can be called in the initializer. + """ + + # This ABCMeta metaclass ensures that we cannot be instantiated without + # the abstract methods being implemented. + __metaclass__ = abc.ABCMeta + + def __init__(self, db_conn, hs): super(ReceiptsWorkerStore, self).__init__(db_conn, hs) - self._receipts_id_gen = receipts_id_gen - self._receipts_stream_cache = StreamChangeCache( - "ReceiptsRoomChangeCache", self._receipts_id_gen.get_current_token() + "ReceiptsRoomChangeCache", self.get_max_receipt_stream_id() ) + @abc.abstractmethod + def get_max_receipt_stream_id(self): + """Get the current max stream ID for receipts stream + + Returns: + int + """ + pass + @cachedInlineCallbacks() def get_users_with_read_receipts_in_room(self, room_id): receipts = yield self.get_receipts_for_room(room_id, "m.read") @@ -260,9 +270,6 @@ class ReceiptsWorkerStore(SQLBaseStore): } defer.returnValue(results) - def get_max_receipt_stream_id(self): - return self._receipts_id_gen.get_current_token() - def get_all_updated_receipts(self, last_id, current_id, limit=None): if last_id == current_id: return defer.succeed([]) @@ -288,11 +295,16 @@ class ReceiptsWorkerStore(SQLBaseStore): class ReceiptsStore(ReceiptsWorkerStore): def __init__(self, db_conn, hs): - receipts_id_gen = StreamIdGenerator( + # We instansiate this first as the ReceiptsWorkerStore constructor + # needs to be able to call get_max_receipt_stream_id + self._receipts_id_gen = StreamIdGenerator( db_conn, "receipts_linearized", "stream_id" ) - super(ReceiptsStore, self).__init__(receipts_id_gen, db_conn, hs) + super(ReceiptsStore, self).__init__(db_conn, hs) + + def get_max_receipt_stream_id(self): + return self._receipts_id_gen.get_current_token() def _invalidate_get_users_with_receipts_in_room(self, room_id, receipt_type, user_id): From 95e4cffd859b0fc3fcd54c755fcc0da403f97b94 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 20 Feb 2018 17:58:40 +0000 Subject: [PATCH 3/4] Fix comment --- synapse/replication/slave/storage/receipts.py | 2 +- synapse/storage/receipts.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/synapse/replication/slave/storage/receipts.py b/synapse/replication/slave/storage/receipts.py index a2eb4a02db..f0e29e9836 100644 --- a/synapse/replication/slave/storage/receipts.py +++ b/synapse/replication/slave/storage/receipts.py @@ -31,7 +31,7 @@ from synapse.storage.receipts import ReceiptsWorkerStore class SlavedReceiptsStore(ReceiptsWorkerStore, BaseSlavedStore): def __init__(self, db_conn, hs): - # We instansiate this first as the ReceiptsWorkerStore constructor + # We instantiate this first as the ReceiptsWorkerStore constructor # needs to be able to call get_max_receipt_stream_id self._receipts_id_gen = SlavedIdTracker( db_conn, "receipts_linearized", "stream_id" diff --git a/synapse/storage/receipts.py b/synapse/storage/receipts.py index b11cf7ff62..c2a6613a62 100644 --- a/synapse/storage/receipts.py +++ b/synapse/storage/receipts.py @@ -295,7 +295,7 @@ class ReceiptsWorkerStore(SQLBaseStore): class ReceiptsStore(ReceiptsWorkerStore): def __init__(self, db_conn, hs): - # We instansiate this first as the ReceiptsWorkerStore constructor + # We instantiate this first as the ReceiptsWorkerStore constructor # needs to be able to call get_max_receipt_stream_id self._receipts_id_gen = StreamIdGenerator( db_conn, "receipts_linearized", "stream_id" From 8fbb4d0d19031d8cd3742285fb9b36c3bdfc52a0 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 20 Feb 2018 17:59:23 +0000 Subject: [PATCH 4/4] Raise exception in abstract method --- synapse/storage/receipts.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/synapse/storage/receipts.py b/synapse/storage/receipts.py index c2a6613a62..40530632c6 100644 --- a/synapse/storage/receipts.py +++ b/synapse/storage/receipts.py @@ -52,7 +52,7 @@ class ReceiptsWorkerStore(SQLBaseStore): Returns: int """ - pass + raise NotImplementedError() @cachedInlineCallbacks() def get_users_with_read_receipts_in_room(self, room_id):