From 3a1f3f838862cfd2486773079d01e30c2237e8c2 Mon Sep 17 00:00:00 2001
From: Erik Johnston
Date: Mon, 27 Mar 2017 16:32:07 +0100
Subject: [PATCH] Change slave storage to use new replication interface

The TCP replication protocol uses a slightly different API and a
different set of streams than the HTTP replication, so the slave
storage classes are updated to use the new interface.

This breaks HTTP replication.
---
 synapse/replication/slave/storage/_base.py    | 31 +++-------
 .../replication/slave/storage/account_data.py | 49 ++++++----------
 .../replication/slave/storage/deviceinbox.py  | 23 ++++----
 synapse/replication/slave/storage/devices.py  | 24 ++++----
 synapse/replication/slave/storage/events.py   | 57 ++++++-------------
 synapse/replication/slave/storage/presence.py | 19 +++----
 .../replication/slave/storage/push_rule.py    | 23 ++++----
 synapse/replication/slave/storage/pushers.py  | 16 ++----
 synapse/replication/slave/storage/receipts.py | 24 ++++----
 synapse/replication/slave/storage/room.py     | 11 ++--
 tests/replication/slave/storage/_base.py      | 30 +++++++---
 11 files changed, 128 insertions(+), 179 deletions(-)

diff --git a/synapse/replication/slave/storage/_base.py b/synapse/replication/slave/storage/_base.py
index ab133db872..b962641166 100644
--- a/synapse/replication/slave/storage/_base.py
+++ b/synapse/replication/slave/storage/_base.py
@@ -15,7 +15,6 @@
 
 from synapse.storage._base import SQLBaseStore
 from synapse.storage.engines import PostgresEngine
-from twisted.internet import defer
 
 from ._slaved_id_tracker import SlavedIdTracker
 
@@ -34,8 +33,7 @@ class BaseSlavedStore(SQLBaseStore):
         else:
             self._cache_id_gen = None
 
-        self.expire_cache_url = hs.config.worker_replication_url + "/expire_cache"
-        self.http_client = hs.get_simple_http_client()
+        self.hs = hs
 
     def stream_positions(self):
         pos = {}
@@ -43,35 +41,20 @@ class BaseSlavedStore(SQLBaseStore):
             pos["caches"] = self._cache_id_gen.get_current_token()
         return pos
 
-    def process_replication(self, result):
-        stream = result.get("caches")
-        if stream:
-            for row in stream["rows"]:
-                (
-                    position, cache_func, keys, invalidation_ts,
-                ) = row
-
+    def process_replication_rows(self, stream_name, token, rows):
+        if stream_name == "caches":
+            self._cache_id_gen.advance(token)
+            for row in rows:
                 try:
-                    getattr(self, cache_func).invalidate(tuple(keys))
+                    getattr(self, row.cache_func).invalidate(tuple(row.keys))
                 except AttributeError:
                     # We probably haven't pulled in the cache in this worker,
                     # which is fine.
                     pass
-            self._cache_id_gen.advance(int(stream["position"]))
-        return defer.succeed(None)
 
     def _invalidate_cache_and_stream(self, txn, cache_func, keys):
         txn.call_after(cache_func.invalidate, keys)
         txn.call_after(self._send_invalidation_poke, cache_func, keys)
 
-    @defer.inlineCallbacks
     def _send_invalidation_poke(self, cache_func, keys):
-        try:
-            yield self.http_client.post_json_get_json(self.expire_cache_url, {
-                "invalidate": [{
-                    "name": cache_func.__name__,
-                    "keys": list(keys),
-                }]
-            })
-        except:
-            logger.exception("Failed to poke on expire_cache")
+        self.hs.get_tcp_replication().send_invalidate_cache(cache_func, keys)
diff --git a/synapse/replication/slave/storage/account_data.py b/synapse/replication/slave/storage/account_data.py
index 77c64722c7..efbd87918e 100644
--- a/synapse/replication/slave/storage/account_data.py
+++ b/synapse/replication/slave/storage/account_data.py
@@ -69,38 +69,25 @@ class SlavedAccountDataStore(BaseSlavedStore):
         result["tag_account_data"] = position
         return result
 
-    def process_replication(self, result):
-        stream = result.get("user_account_data")
-        if stream:
-            self._account_data_id_gen.advance(int(stream["position"]))
-            for row in stream["rows"]:
-                position, user_id, data_type = row[:3]
-                self.get_global_account_data_by_type_for_user.invalidate(
-                    (data_type, user_id,)
-                )
-                self.get_account_data_for_user.invalidate((user_id,))
+    def process_replication_rows(self, stream_name, token, rows):
+        if stream_name == "tag_account_data":
+            self._account_data_id_gen.advance(token)
+            for row in rows:
+                self.get_tags_for_user.invalidate((row.user_id,))
                 self._account_data_stream_cache.entity_has_changed(
-                    user_id, position
+                    row.user_id, token
                 )
-
-        stream = result.get("room_account_data")
-        if stream:
-            self._account_data_id_gen.advance(int(stream["position"]))
-            for row in stream["rows"]:
-                position, user_id = row[:2]
-                self.get_account_data_for_user.invalidate((user_id,))
+        elif stream_name == "account_data":
+            self._account_data_id_gen.advance(token)
+            for row in rows:
+                if not row.room_id:
+                    self.get_global_account_data_by_type_for_user.invalidate(
+                        (row.data_type, row.user_id,)
+                    )
+                self.get_account_data_for_user.invalidate((row.user_id,))
                 self._account_data_stream_cache.entity_has_changed(
-                    user_id, position
+                    row.user_id, token
                 )
-
-        stream = result.get("tag_account_data")
-        if stream:
-            self._account_data_id_gen.advance(int(stream["position"]))
-            for row in stream["rows"]:
-                position, user_id = row[:2]
-                self.get_tags_for_user.invalidate((user_id,))
-                self._account_data_stream_cache.entity_has_changed(
-                    user_id, position
-                )
-
-        return super(SlavedAccountDataStore, self).process_replication(result)
+        return super(SlavedAccountDataStore, self).process_replication_rows(
+            stream_name, token, rows
+        )
diff --git a/synapse/replication/slave/storage/deviceinbox.py b/synapse/replication/slave/storage/deviceinbox.py
index f9102e0d89..6f3fb64770 100644
--- a/synapse/replication/slave/storage/deviceinbox.py
+++ b/synapse/replication/slave/storage/deviceinbox.py
@@ -53,21 +53,18 @@ class SlavedDeviceInboxStore(BaseSlavedStore):
         result["to_device"] = self._device_inbox_id_gen.get_current_token()
         return result
 
-    def process_replication(self, result):
-        stream = result.get("to_device")
-        if stream:
-            self._device_inbox_id_gen.advance(int(stream["position"]))
-            for row in stream["rows"]:
-                stream_id = row[0]
-                entity = row[1]
-
-                if entity.startswith("@"):
+    def process_replication_rows(self, stream_name, token, rows):
+        if stream_name == "to_device":
+            self._device_inbox_id_gen.advance(token)
+            for row in rows:
+                if row.entity.startswith("@"):
                     self._device_inbox_stream_cache.entity_has_changed(
-                        entity, stream_id
+                        row.entity, token
                     )
                 else:
                     self._device_federation_outbox_stream_cache.entity_has_changed(
-                        entity, stream_id
+                        row.entity, token
                     )
-
-        return super(SlavedDeviceInboxStore, self).process_replication(result)
+        return super(SlavedDeviceInboxStore, self).process_replication_rows(
+            stream_name, token, rows
+        )
diff --git a/synapse/replication/slave/storage/devices.py b/synapse/replication/slave/storage/devices.py
index ca46aa17b6..4d4a435471 100644
--- a/synapse/replication/slave/storage/devices.py
+++ b/synapse/replication/slave/storage/devices.py
@@ -51,22 +51,18 @@ class SlavedDeviceStore(BaseSlavedStore):
         result["device_lists"] = self._device_list_id_gen.get_current_token()
         return result
 
-    def process_replication(self, result):
-        stream = result.get("device_lists")
-        if stream:
-            self._device_list_id_gen.advance(int(stream["position"]))
-            for row in stream["rows"]:
-                stream_id = row[0]
-                user_id = row[1]
-                destination = row[2]
-
+    def process_replication_rows(self, stream_name, token, rows):
+        if stream_name == "device_lists":
+            self._device_list_id_gen.advance(token)
+            for row in rows:
                 self._device_list_stream_cache.entity_has_changed(
-                    user_id, stream_id
+                    row.user_id, token
                 )
 
-                if destination:
+                if row.destination:
                     self._device_list_federation_stream_cache.entity_has_changed(
-                        destination, stream_id
+                        row.destination, token
                     )
-
-        return super(SlavedDeviceStore, self).process_replication(result)
+        return super(SlavedDeviceStore, self).process_replication_rows(
+            stream_name, token, rows
+        )
diff --git a/synapse/replication/slave/storage/events.py b/synapse/replication/slave/storage/events.py
index d4db1e452e..5fd47706ef 100644
--- a/synapse/replication/slave/storage/events.py
+++ b/synapse/replication/slave/storage/events.py
@@ -201,48 +201,25 @@ class SlavedEventStore(BaseSlavedStore):
         result["backfill"] = -self._backfill_id_gen.get_current_token()
         return result
 
-    def process_replication(self, result):
-        stream = result.get("events")
-        if stream:
-            self._stream_id_gen.advance(int(stream["position"]))
-
-            if stream["rows"]:
-                logger.info("Got %d event rows", len(stream["rows"]))
-
-            for row in stream["rows"]:
-                self._process_replication_row(
-                    row, backfilled=False,
+    def process_replication_rows(self, stream_name, token, rows):
+        if stream_name == "events":
+            self._stream_id_gen.advance(token)
+            for row in rows:
+                self.invalidate_caches_for_event(
+                    token, row.event_id, row.room_id, row.type, row.state_key,
+                    row.redacts,
+                    backfilled=False,
                 )
-
-        stream = result.get("backfill")
-        if stream:
-            self._backfill_id_gen.advance(-int(stream["position"]))
-            for row in stream["rows"]:
-                self._process_replication_row(
-                    row, backfilled=True,
+        elif stream_name == "backfill":
+            self._backfill_id_gen.advance(-token)
+            for row in rows:
+                self.invalidate_caches_for_event(
+                    -token, row.event_id, row.room_id, row.type, row.state_key,
+                    row.redacts,
+                    backfilled=True,
                 )
-
-        stream = result.get("forward_ex_outliers")
-        if stream:
-            self._stream_id_gen.advance(int(stream["position"]))
-            for row in stream["rows"]:
-                event_id = row[1]
-                self._invalidate_get_event_cache(event_id)
-
-        stream = result.get("backward_ex_outliers")
-        if stream:
-            self._backfill_id_gen.advance(-int(stream["position"]))
-            for row in stream["rows"]:
-                event_id = row[1]
-                self._invalidate_get_event_cache(event_id)
-
-        return super(SlavedEventStore, self).process_replication(result)
-
-    def _process_replication_row(self, row, backfilled):
-        stream_ordering = row[0] if not backfilled else -row[0]
-        self.invalidate_caches_for_event(
-            stream_ordering, row[1], row[2], row[3], row[4], row[5],
-            backfilled=backfilled,
+        return super(SlavedEventStore, self).process_replication_rows(
+            stream_name, token, rows
         )
 
     def invalidate_caches_for_event(self, stream_ordering, event_id, room_id,
diff --git a/synapse/replication/slave/storage/presence.py b/synapse/replication/slave/storage/presence.py
index e4a2414d78..dffc80adc3 100644
--- a/synapse/replication/slave/storage/presence.py
+++ b/synapse/replication/slave/storage/presence.py
@@ -48,15 +48,14 @@ class SlavedPresenceStore(BaseSlavedStore):
         result["presence"] = position
         return result
 
-    def process_replication(self, result):
-        stream = result.get("presence")
-        if stream:
-            self._presence_id_gen.advance(int(stream["position"]))
-            for row in stream["rows"]:
-                position, user_id = row[:2]
+    def process_replication_rows(self, stream_name, token, rows):
+        if stream_name == "presence":
+            self._presence_id_gen.advance(token)
+            for row in rows:
                 self.presence_stream_cache.entity_has_changed(
-                    user_id, position
+                    row.user_id, token
                 )
-                self._get_presence_for_user.invalidate((user_id,))
-
-        return super(SlavedPresenceStore, self).process_replication(result)
+                self._get_presence_for_user.invalidate((row.user_id,))
+        return super(SlavedPresenceStore, self).process_replication_rows(
+            stream_name, token, rows
+        )
diff --git a/synapse/replication/slave/storage/push_rule.py b/synapse/replication/slave/storage/push_rule.py
index 21ceb0213a..83e880fdd2 100644
--- a/synapse/replication/slave/storage/push_rule.py
+++ b/synapse/replication/slave/storage/push_rule.py
@@ -50,18 +50,15 @@ class SlavedPushRuleStore(SlavedEventStore):
         result["push_rules"] = self._push_rules_stream_id_gen.get_current_token()
         return result
 
-    def process_replication(self, result):
-        stream = result.get("push_rules")
-        if stream:
-            for row in stream["rows"]:
-                position = row[0]
-                user_id = row[2]
-                self.get_push_rules_for_user.invalidate((user_id,))
-                self.get_push_rules_enabled_for_user.invalidate((user_id,))
+    def process_replication_rows(self, stream_name, token, rows):
+        if stream_name == "push_rules":
+            self._push_rules_stream_id_gen.advance(token)
+            for row in rows:
+                self.get_push_rules_for_user.invalidate((row.user_id,))
+                self.get_push_rules_enabled_for_user.invalidate((row.user_id,))
                 self.push_rules_stream_cache.entity_has_changed(
-                    user_id, position
+                    row.user_id, token
                 )
-
-            self._push_rules_stream_id_gen.advance(int(stream["position"]))
-
-        return super(SlavedPushRuleStore, self).process_replication(result)
+        return super(SlavedPushRuleStore, self).process_replication_rows(
+            stream_name, token, rows
+        )
diff --git a/synapse/replication/slave/storage/pushers.py b/synapse/replication/slave/storage/pushers.py
index d88206b3bb..4e8d68ece9 100644
--- a/synapse/replication/slave/storage/pushers.py
+++ b/synapse/replication/slave/storage/pushers.py
@@ -40,13 +40,9 @@ class SlavedPusherStore(BaseSlavedStore):
         result["pushers"] = self._pushers_id_gen.get_current_token()
         return result
 
-    def process_replication(self, result):
-        stream = result.get("pushers")
-        if stream:
-            self._pushers_id_gen.advance(int(stream["position"]))
-
-        stream = result.get("deleted_pushers")
-        if stream:
-            self._pushers_id_gen.advance(int(stream["position"]))
-
-        return super(SlavedPusherStore, self).process_replication(result)
+    def process_replication_rows(self, stream_name, token, rows):
+        if stream_name == "pushers":
+            self._pushers_id_gen.advance(token)
+        return super(SlavedPusherStore, self).process_replication_rows(
+            stream_name, token, rows
+        )
diff --git a/synapse/replication/slave/storage/receipts.py b/synapse/replication/slave/storage/receipts.py
index ac9662d399..b371574ece 100644
--- a/synapse/replication/slave/storage/receipts.py
+++ b/synapse/replication/slave/storage/receipts.py
@@ -65,20 +65,22 @@ class SlavedReceiptsStore(BaseSlavedStore):
         result["receipts"] = self._receipts_id_gen.get_current_token()
         return result
 
-    def process_replication(self, result):
-        stream = result.get("receipts")
-        if stream:
-            self._receipts_id_gen.advance(int(stream["position"]))
-            for row in stream["rows"]:
-                position, room_id, receipt_type, user_id = row[:4]
-                self.invalidate_caches_for_receipt(room_id, receipt_type, user_id)
-                self._receipts_stream_cache.entity_has_changed(room_id, position)
-
-        return super(SlavedReceiptsStore, self).process_replication(result)
-
     def invalidate_caches_for_receipt(self, room_id, receipt_type, user_id):
         self.get_receipts_for_user.invalidate((user_id, receipt_type))
         self.get_linearized_receipts_for_room.invalidate_many((room_id,))
         self.get_last_receipt_event_id_for_user.invalidate(
             (user_id, room_id, receipt_type)
         )
+
+    def process_replication_rows(self, stream_name, token, rows):
+        if stream_name == "receipts":
+            self._receipts_id_gen.advance(token)
+            for row in rows:
+                self.invalidate_caches_for_receipt(
+                    row.room_id, row.receipt_type, row.user_id
+                )
+                self._receipts_stream_cache.entity_has_changed(row.room_id, token)
+
+        return super(SlavedReceiptsStore, self).process_replication_rows(
+            stream_name, token, rows
+        )
diff --git a/synapse/replication/slave/storage/room.py b/synapse/replication/slave/storage/room.py
index 6df9a25ef3..f510384033 100644
--- a/synapse/replication/slave/storage/room.py
+++ b/synapse/replication/slave/storage/room.py
@@ -46,9 +46,10 @@ class RoomStore(BaseSlavedStore):
         result["public_rooms"] = self._public_room_id_gen.get_current_token()
         return result
 
-    def process_replication(self, result):
-        stream = result.get("public_rooms")
-        if stream:
-            self._public_room_id_gen.advance(int(stream["position"]))
+    def process_replication_rows(self, stream_name, token, rows):
+        if stream_name == "public_rooms":
+            self._public_room_id_gen.advance(token)
 
-        return super(RoomStore, self).process_replication(result)
+        return super(RoomStore, self).process_replication_rows(
+            stream_name, token, rows
+        )
diff --git a/tests/replication/slave/storage/_base.py b/tests/replication/slave/storage/_base.py
index b82868054d..81063f19a1 100644
--- a/tests/replication/slave/storage/_base.py
+++ b/tests/replication/slave/storage/_base.py
@@ -12,12 +12,15 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from twisted.internet import defer
+from twisted.internet import defer, reactor
 from tests import unittest
 
 from mock import Mock, NonCallableMock
 
 from tests.utils import setup_test_homeserver
-from synapse.replication.resource import ReplicationResource
+from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory
+from synapse.replication.tcp.client import (
+    ReplicationClientHandler, ReplicationClientFactory,
+)
 
 class BaseSlavedStoreTestCase(unittest.TestCase):
@@ -33,18 +36,29 @@
         )
         self.hs.get_ratelimiter().send_message.return_value = (True, 0)
 
-        self.replication = ReplicationResource(self.hs)
-
         self.master_store = self.hs.get_datastore()
         self.slaved_store = self.STORE_TYPE(self.hs.get_db_conn(), self.hs)
         self.event_id = 0
 
+        server_factory = ReplicationStreamProtocolFactory(self.hs)
+        listener = reactor.listenUNIX("\0xxx", server_factory)
+        self.addCleanup(listener.stopListening)
+        self.streamer = server_factory.streamer
+
+        self.replication_handler = ReplicationClientHandler(self.slaved_store)
+        client_factory = ReplicationClientFactory(
+            self.hs, "client_name", self.replication_handler
+        )
+        client_connector = reactor.connectUNIX("\0xxx", client_factory)
+        self.addCleanup(client_factory.stopTrying)
+        self.addCleanup(client_connector.disconnect)
+
     @defer.inlineCallbacks
     def replicate(self):
-        streams = self.slaved_store.stream_positions()
-        writer = yield self.replication.replicate(streams, 100)
-        result = writer.finish()
-        yield self.slaved_store.process_replication(result)
+        yield self.streamer.on_notifier_poke()
+        d = self.replication_handler.await_sync("replication_test")
+        self.streamer.send_sync_to_all_connections("replication_test")
+        yield d
 
     @defer.inlineCallbacks
     def check(self, method, args, expected_result=None):
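Every store touched by this patch follows the same pattern: match on the stream name, advance the store's id tracker to the new token, invalidate the affected caches row by row, and finally chain to super() so that sibling stores in the MRO see the rows too. The sketch below distills that shape into a self-contained example; ExampleSlavedStore, ExampleRow and the stub tracker/cache classes are illustrative stand-ins rather than Synapse classes, and only the stream_positions() and process_replication_rows(stream_name, token, rows) signatures are taken from the patch itself.

    from collections import namedtuple

    # Rows now arrive as attribute-bearing objects (row.user_id) rather than
    # the bare lists (row[1]) that the old HTTP replication API delivered.
    ExampleRow = namedtuple("ExampleRow", ("user_id",))


    class StubIdTracker(object):
        """Stand-in for SlavedIdTracker: remembers the latest stream token."""

        def __init__(self, token=0):
            self._token = token

        def get_current_token(self):
            return self._token

        def advance(self, token):
            self._token = max(self._token, token)


    class StubCache(object):
        """Stand-in for a cached function / stream change cache."""

        def invalidate(self, key):
            print("invalidated %r" % (key,))

        def entity_has_changed(self, entity, token):
            print("%s changed at stream position %s" % (entity, token))


    class ExampleSlavedStore(object):
        """Follows the shape of the slaved stores in this patch."""

        def __init__(self):
            self._id_tracker = StubIdTracker()
            self._user_cache = StubCache()
            self._stream_cache = StubCache()

        def stream_positions(self):
            # Unchanged by the patch: report how far this store has got on
            # each stream, so the master knows which rows still need sending.
            return {"example_stream": self._id_tracker.get_current_token()}

        def process_replication_rows(self, stream_name, token, rows):
            # Handle only the streams this store cares about; a real store
            # ends by chaining to super() so sibling stores see the rows too.
            if stream_name == "example_stream":
                self._id_tracker.advance(token)
                for row in rows:
                    self._user_cache.invalidate((row.user_id,))
                    self._stream_cache.entity_has_changed(row.user_id, token)


    store = ExampleSlavedStore()
    store.process_replication_rows(
        "example_stream", 7, [ExampleRow("@alice:example.com")]
    )
    assert store.stream_positions() == {"example_stream": 7}

One wrinkle visible in events.py: the backfill stream runs on negated tokens, since its id generator counts downwards, which is why that branch calls advance(-token) and passes -token as the stream ordering.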
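The rewritten replicate() helper in the tests is worth a note as well. It leans on the ordering guarantee of the TCP connection: after poking the streamer, the test sends a SYNC marker to all connections and waits for the client to echo it back; because commands arrive in order, once the await_sync deferred fires, every replication row queued before the marker has already been handed to the slaved store. A rough model of that handshake follows; StubSyncHandler is a hypothetical stand-in, the real logic living in ReplicationClientHandler.

    from twisted.internet import defer


    class StubSyncHandler(object):
        """Models the await_sync()/SYNC round trip used by replicate()."""

        def __init__(self):
            self._waiting = {}

        def await_sync(self, data):
            # Returns a deferred that fires once the server echoes `data`
            # back in a SYNC command.
            d = defer.Deferred()
            self._waiting[data] = d
            return d

        def on_sync(self, data):
            # Called when a SYNC command arrives. Because the connection is
            # ordered, every row sent before the marker has been processed.
            d = self._waiting.pop(data, None)
            if d is not None:
                d.callback(None)

Here handler.on_sync("replication_test") would release the deferred obtained from handler.await_sync("replication_test"), mirroring the send_sync_to_all_connections("replication_test") call in the test.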