From 90bd1708b5961b938eebd9c09658c3a39ad77c5e Mon Sep 17 00:00:00 2001
From: Erik Johnston <erik@matrix.org>
Date: Tue, 31 Mar 2020 13:25:09 +0100
Subject: [PATCH] Fix test

---
 tests/replication/slave/storage/_base.py      | 20 ++++++----
 tests/replication/tcp/streams/_base.py        | 38 +++++++------------
 .../replication/tcp/streams/test_receipts.py  |  1 -
 3 files changed, 26 insertions(+), 33 deletions(-)

diff --git a/tests/replication/slave/storage/_base.py b/tests/replication/slave/storage/_base.py
index 2a1e7c7166..1deec91e07 100644
--- a/tests/replication/slave/storage/_base.py
+++ b/tests/replication/slave/storage/_base.py
@@ -15,10 +15,9 @@
 
 from mock import Mock, NonCallableMock
 
-from synapse.replication.tcp.client import (
-    ReplicationClientFactory,
-    ReplicationClientHandler,
-)
+from synapse.replication.tcp.client import ReplicationClientFactory
+from synapse.replication.tcp.handler import ReplicationCommandHandler
+from synapse.replication.tcp.client import ReplicationDataHandler
 from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory
 from synapse.storage.database import make_conn
 
@@ -51,15 +50,20 @@ class BaseSlavedStoreTestCase(unittest.HomeserverTestCase):
         self.event_id = 0
 
         server_factory = ReplicationStreamProtocolFactory(self.hs)
-        self.streamer = server_factory.streamer
+        self.streamer = hs.get_replication_streamer()
 
-        handler_factory = Mock()
-        self.replication_handler = ReplicationClientHandler(self.slaved_store)
-        self.replication_handler.factory = handler_factory
+        # We now do some gut wrenching so that we have a client that is based
+        # off of the slave store rather than the main store.
+        self.replication_handler = ReplicationCommandHandler(self.hs)
+        self.replication_handler.store = self.slaved_store
+        self.replication_handler.replication_data_handler = ReplicationDataHandler(
+            self.slaved_store
+        )
 
         client_factory = ReplicationClientFactory(
             self.hs, "client_name", self.replication_handler
         )
+        client_factory.handler = self.replication_handler
 
         server = server_factory.buildProtocol(None)
         client = client_factory.buildProtocol(None)
diff --git a/tests/replication/tcp/streams/_base.py b/tests/replication/tcp/streams/_base.py
index a755fe2879..920ddb0c8b 100644
--- a/tests/replication/tcp/streams/_base.py
+++ b/tests/replication/tcp/streams/_base.py
@@ -15,7 +15,7 @@
 
 from mock import Mock
 
-from synapse.replication.tcp.commands import ReplicateCommand
+from synapse.replication.tcp.handler import ReplicationCommandHandler
 from synapse.replication.tcp.protocol import ClientReplicationStreamProtocol
 from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory
 
@@ -26,15 +26,20 @@ from tests.server import FakeTransport
 class BaseStreamTestCase(unittest.HomeserverTestCase):
     """Base class for tests of the replication streams"""
 
+    def make_homeserver(self, reactor, clock):
+        self.test_handler = Mock(wraps=TestReplicationClientHandler())
+        return self.setup_test_homeserver(replication_data_handler=self.test_handler)
+
     def prepare(self, reactor, clock, hs):
         # build a replication server
-        server_factory = ReplicationStreamProtocolFactory(self.hs)
-        self.streamer = server_factory.streamer
+        server_factory = ReplicationStreamProtocolFactory(hs)
+        self.streamer = hs.get_replication_streamer()
         self.server = server_factory.buildProtocol(None)
 
-        self.test_handler = Mock(wraps=TestReplicationClientHandler())
+        repl_handler = ReplicationCommandHandler(hs)
+        repl_handler.handler = self.test_handler
         self.client = ClientReplicationStreamProtocol(
-            hs, "client", "test", clock, self.test_handler,
+            hs, "client", "test", clock, repl_handler,
         )
 
         self._client_transport = None
@@ -69,14 +74,8 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
         self.streamer.on_notifier_poke()
         self.pump(0.1)
 
-    def replicate_stream(self):
-        """Make the client end a REPLICATE command to set up a subscription to a stream"""
-        self.client.send_command(ReplicateCommand())
-
-
-class TestReplicationClientHandler(object):
-    """Drop-in for ReplicationClientHandler which just collects RDATA rows"""
 
+class TestReplicationClientHandler:
     def __init__(self):
         self.streams = set()
         self._received_rdata_rows = []
@@ -88,18 +87,9 @@ class TestReplicationClientHandler(object):
                 positions[stream] = max(token, positions.get(stream, 0))
         return positions
 
-    def get_currently_syncing_users(self):
-        return []
-
-    def update_connection(self, connection):
-        pass
-
-    def finished_connecting(self):
-        pass
-
-    async def on_position(self, stream_name, token):
-        """Called when we get new position data."""
-
     async def on_rdata(self, stream_name, token, rows):
         for r in rows:
             self._received_rdata_rows.append((stream_name, token, r))
+
+    async def on_position(self, stream_name, token):
+        pass
diff --git a/tests/replication/tcp/streams/test_receipts.py b/tests/replication/tcp/streams/test_receipts.py
index 0ec0825a0e..a0206f7363 100644
--- a/tests/replication/tcp/streams/test_receipts.py
+++ b/tests/replication/tcp/streams/test_receipts.py
@@ -24,7 +24,6 @@ class ReceiptsStreamTestCase(BaseStreamTestCase):
         self.reconnect()
 
         # make the client subscribe to the receipts stream
-        self.replicate_stream()
         self.test_handler.streams.add("receipts")
 
         # tell the master to send a new receipt