Allow HTTP replication between workers in tests

pull/8433/head
Erik Johnston 2020-09-30 17:23:45 +01:00
parent 7941372ec8
commit 2de676cee3
2 changed files with 50 additions and 14 deletions

View File

@ -27,7 +27,7 @@ from synapse.app.generic_worker import (
GenericWorkerServer, GenericWorkerServer,
) )
from synapse.http.server import JsonResource from synapse.http.server import JsonResource
from synapse.http.site import SynapseRequest from synapse.http.site import SynapseRequest, SynapseSite
from synapse.replication.http import ReplicationRestResource, streams from synapse.replication.http import ReplicationRestResource, streams
from synapse.replication.tcp.handler import ReplicationCommandHandler from synapse.replication.tcp.handler import ReplicationCommandHandler
from synapse.replication.tcp.protocol import ClientReplicationStreamProtocol from synapse.replication.tcp.protocol import ClientReplicationStreamProtocol
@ -202,14 +202,20 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
self.reactor.lookups["testserv"] = "1.2.3.4" self.reactor.lookups["testserv"] = "1.2.3.4"
self._worker_hs_to_resource = {} # A map from a HS instance to the associated HTTP Site to use for
# handling inbound HTTP requests to that instance.
self._hs_to_site = {self.hs: self.site}
# When we see a connection attempt to the master replication listener we # When we see a connection attempt to the replication listener on a HS
# automatically set up the connection. This is so that tests don't # we automatically set up the connection. This is so that tests don't
# manually have to go and explicitly set it up each time (plus sometimes # manually have to go and explicitly set it up each time (plus sometimes
# it is impossible to write the handling explicitly in the tests). # it is impossible to write the handling explicitly in the tests).
#
# This sets registers the master replication listener:
self.reactor.add_tcp_client_callback( self.reactor.add_tcp_client_callback(
"1.2.3.4", 8765, self._handle_http_replication_attempt "1.2.3.4",
8765,
lambda: self._handle_http_replication_attempt(self.hs, 8765),
) )
def create_test_json_resource(self): def create_test_json_resource(self):
@ -253,9 +259,31 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
**kwargs **kwargs
) )
# If the instance is in the `instance_map` config then workers may try
# and send HTTP requests to it, so we register it with
# `_handle_http_replication_attempt` like we do with the master HS.
instance_name = worker_hs.get_instance_name()
instance_loc = worker_hs.config.worker.instance_map.get(instance_name)
if instance_loc:
# Ensure the host is one that has a fake DNS entry.
if instance_loc.host not in self.reactor.lookups:
raise Exception(
"Host does not have an IP for instance_map[%r].host = %r"
% (instance_name, instance_loc.host,)
)
self.reactor.add_tcp_client_callback(
self.reactor.lookups[instance_loc.host],
instance_loc.port,
lambda: self._handle_http_replication_attempt(
worker_hs, instance_loc.port
),
)
store = worker_hs.get_datastore() store = worker_hs.get_datastore()
store.db_pool._db_pool = self.database_pool._db_pool store.db_pool._db_pool = self.database_pool._db_pool
# Set up TCP replication between master and the new worker.
repl_handler = ReplicationCommandHandler(worker_hs) repl_handler = ReplicationCommandHandler(worker_hs)
client = ClientReplicationStreamProtocol( client = ClientReplicationStreamProtocol(
worker_hs, "client", "test", self.clock, repl_handler, worker_hs, "client", "test", self.clock, repl_handler,
@ -269,12 +297,20 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
server.makeConnection(server_transport) server.makeConnection(server_transport)
# Set up a resource for the worker # Set up a resource for the worker
resource = ReplicationRestResource(self.hs) resource = ReplicationRestResource(worker_hs)
for servlet in self.servlets: for servlet in self.servlets:
servlet(worker_hs, resource) servlet(worker_hs, resource)
self._worker_hs_to_resource[worker_hs] = resource self._hs_to_site[worker_hs] = SynapseSite(
logger_name="synapse.access.http.fake",
site_tag="{}-{}".format(
worker_hs.config.server.server_name, worker_hs.get_instance_name()
),
config=worker_hs.config.server.listeners[0],
resource=resource,
server_version_string="1",
)
return worker_hs return worker_hs
@ -285,7 +321,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
return config return config
def render_on_worker(self, worker_hs: HomeServer, request: SynapseRequest): def render_on_worker(self, worker_hs: HomeServer, request: SynapseRequest):
render(request, self._worker_hs_to_resource[worker_hs], self.reactor) render(request, self._hs_to_site[worker_hs].resource, self.reactor)
def replicate(self): def replicate(self):
"""Tell the master side of replication that something has happened, and then """Tell the master side of replication that something has happened, and then
@ -294,9 +330,9 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
self.streamer.on_notifier_poke() self.streamer.on_notifier_poke()
self.pump() self.pump()
def _handle_http_replication_attempt(self): def _handle_http_replication_attempt(self, hs, repl_port):
"""Handles a connection attempt to the master replication HTTP """Handles a connection attempt to the given HS replication HTTP
listener. listener on the given port.
""" """
# We should have at least one outbound connection attempt, where the # We should have at least one outbound connection attempt, where the
@ -305,7 +341,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
self.assertGreaterEqual(len(clients), 1) self.assertGreaterEqual(len(clients), 1)
(host, port, client_factory, _timeout, _bindAddress) = clients.pop() (host, port, client_factory, _timeout, _bindAddress) = clients.pop()
self.assertEqual(host, "1.2.3.4") self.assertEqual(host, "1.2.3.4")
self.assertEqual(port, 8765) self.assertEqual(port, repl_port)
# Set up client side protocol # Set up client side protocol
client_protocol = client_factory.buildProtocol(None) client_protocol = client_factory.buildProtocol(None)
@ -315,7 +351,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
# Set up the server side protocol # Set up the server side protocol
channel = _PushHTTPChannel(self.reactor) channel = _PushHTTPChannel(self.reactor)
channel.requestFactory = request_factory channel.requestFactory = request_factory
channel.site = self.site channel.site = self._hs_to_site[hs]
# Connect client to server and vice versa. # Connect client to server and vice versa.
client_to_server_transport = FakeTransport( client_to_server_transport = FakeTransport(

View File

@ -241,7 +241,7 @@ class HomeserverTestCase(TestCase):
# create a site to wrap the resource. # create a site to wrap the resource.
self.site = SynapseSite( self.site = SynapseSite(
logger_name="synapse.access.http.fake", logger_name="synapse.access.http.fake",
site_tag="test", site_tag=self.hs.config.server.server_name,
config=self.hs.config.server.listeners[0], config=self.hs.config.server.listeners[0],
resource=self.resource, resource=self.resource,
server_version_string="1", server_version_string="1",