Add redis support to replication tests

pull/8433/head
Erik Johnston 2020-09-30 18:45:55 +01:00
parent 2de676cee3
commit ff35d07e92
3 changed files with 203 additions and 16 deletions

View File

@ -251,10 +251,9 @@ class ReplicationCommandHandler:
using TCP.
"""
if hs.config.redis.redis_enabled:
import txredisapi
from synapse.replication.tcp.redis import (
RedisDirectTcpReplicationClientFactory,
lazyConnection,
)
logger.info(
@ -271,7 +270,8 @@ class ReplicationCommandHandler:
# connection after SUBSCRIBE is called).
# First create the connection for sending commands.
outbound_redis_connection = txredisapi.lazyConnection(
outbound_redis_connection = lazyConnection(
reactor=hs.get_reactor(),
host=hs.config.redis_host,
port=hs.config.redis_port,
password=hs.config.redis.redis_password,

View File

@ -228,3 +228,44 @@ class RedisDirectTcpReplicationClientFactory(txredisapi.SubscriberFactory):
p.password = self.password
return p
def lazyConnection(
reactor,
host="localhost",
port=6379,
dbid=None,
reconnect=True,
charset="utf-8",
password=None,
connectTimeout=None,
replyTimeout=None,
convertNumbers=True,
):
"""Equivalent to `txredisapi.lazyConnection`, except allows specifying a
reactor.
"""
isLazy = True
poolsize = 1
uuid = "%s:%d" % (host, port)
factory = txredisapi.RedisFactory(
uuid,
dbid,
poolsize,
isLazy,
txredisapi.ConnectionHandler,
charset,
password,
replyTimeout,
convertNumbers,
)
factory.continueTrying = reconnect
for x in range(poolsize):
reactor.connectTCP(host, port, factory, connectTimeout)
if isLazy:
return factory.handler
else:
return factory.deferred

View File

@ -12,13 +12,14 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import Any, Callable, List, Optional, Tuple
import attr
import hiredis
from twisted.internet.interfaces import IConsumer, IPullProducer, IReactorTime
from twisted.internet.protocol import Protocol
from twisted.internet.task import LoopingCall
from twisted.web.http import HTTPChannel
@ -197,17 +198,29 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
self.server_factory = ReplicationStreamProtocolFactory(self.hs)
self.streamer = self.hs.get_replication_streamer()
# Fake in memory Redis server that servers can connect to.
self._redis_server = FakeRedisPubSubServer()
store = self.hs.get_datastore()
self.database_pool = store.db_pool
self.reactor.lookups["testserv"] = "1.2.3.4"
self.reactor.lookups["localhost"] = "127.0.0.1"
# A map from a HS instance to the associated HTTP Site to use for
# handling inbound HTTP requests to that instance.
self._hs_to_site = {self.hs: self.site}
# When we see a connection attempt to the replication listener on a HS
# we automatically set up the connection. This is so that tests don't
if self.hs.config.redis.redis_enabled:
# Handle attempts to connect to fake redis server.
self.reactor.add_tcp_client_callback(
"localhost", 6379, self.connect_any_redis_attempts,
)
self.hs.get_tcp_replication().start_replication(self.hs)
# When we see a connection attempt to the master replication listener we
# automatically set up the connection. This is so that tests don't
# manually have to go and explicitly set it up each time (plus sometimes
# it is impossible to write the handling explicitly in the tests).
#
@ -283,18 +296,20 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
store = worker_hs.get_datastore()
store.db_pool._db_pool = self.database_pool._db_pool
# Set up TCP replication between master and the new worker.
repl_handler = ReplicationCommandHandler(worker_hs)
client = ClientReplicationStreamProtocol(
worker_hs, "client", "test", self.clock, repl_handler,
)
server = self.server_factory.buildProtocol(None)
# Set up TCP replication between master and the new worker if we don't
# have Redis support enabled.
if not worker_hs.config.redis_enabled:
repl_handler = ReplicationCommandHandler(worker_hs)
client = ClientReplicationStreamProtocol(
worker_hs, "client", "test", self.clock, repl_handler,
)
server = self.server_factory.buildProtocol(None)
client_transport = FakeTransport(server, self.reactor)
client.makeConnection(client_transport)
client_transport = FakeTransport(server, self.reactor)
client.makeConnection(client_transport)
server_transport = FakeTransport(client, self.reactor)
server.makeConnection(server_transport)
server_transport = FakeTransport(client, self.reactor)
server.makeConnection(server_transport)
# Set up a resource for the worker
resource = ReplicationRestResource(worker_hs)
@ -312,6 +327,9 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
server_version_string="1",
)
if worker_hs.config.redis.redis_enabled:
worker_hs.get_tcp_replication().start_replication(worker_hs)
return worker_hs
def _get_worker_hs_config(self) -> dict:
@ -369,6 +387,32 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
# inside `connecTCP` before the connection has been passed back to the
# code that requested the TCP connection.
def connect_any_redis_attempts(self):
"""If redis is enabled we need to deal with workers connecting to a
redis server. We don't want to use a real Redis server so we use a
fake one.
"""
clients = self.reactor.tcpClients
self.assertEqual(len(clients), 1)
(host, port, client_factory, _timeout, _bindAddress) = clients.pop(0)
self.assertEqual(host, "localhost")
self.assertEqual(port, 6379)
client_protocol = client_factory.buildProtocol(None)
server_protocol = self._redis_server.buildProtocol(None)
client_to_server_transport = FakeTransport(
server_protocol, self.reactor, client_protocol
)
client_protocol.makeConnection(client_to_server_transport)
server_to_client_transport = FakeTransport(
client_protocol, self.reactor, server_protocol
)
server_protocol.makeConnection(server_to_client_transport)
return client_to_server_transport, server_to_client_transport
class TestReplicationDataHandler(GenericWorkerReplicationHandler):
"""Drop-in for ReplicationDataHandler which just collects RDATA rows"""
@ -503,3 +547,105 @@ class _PullToPushProducer:
pass
self.stopProducing()
class FakeRedisPubSubServer:
"""A fake Redis server for pub/sub.
"""
def __init__(self):
self._subscribers = set()
def add_subscriber(self, conn):
"""A connection has called SUBSCRIBE
"""
self._subscribers.add(conn)
def remove_subscriber(self, conn):
"""A connection has called UNSUBSCRIBE
"""
self._subscribers.discard(conn)
def publish(self, conn, channel, msg) -> int:
"""A connection want to publish a message to subscribers.
"""
for sub in self._subscribers:
sub.send(["message", channel, msg])
return len(self._subscribers)
def buildProtocol(self, addr):
return FakeRedisPubSubProtocol(self)
class FakeRedisPubSubProtocol(Protocol):
"""A connection from a client talking to the fake Redis server.
"""
def __init__(self, server: FakeRedisPubSubServer):
self._server = server
self._reader = hiredis.Reader()
def dataReceived(self, data):
self._reader.feed(data)
# We might get multiple messages in one packet.
while True:
msg = self._reader.gets()
if msg is False:
# No more messages.
return
if not isinstance(msg, list):
# Inbound commands should always be a list
raise Exception("Expected redis list")
self.handle_command(msg[0], *msg[1:])
def handle_command(self, command, *args):
"""Received a Redis command from the client.
"""
# We currently only support pub/sub.
if command == b"PUBLISH":
channel, message = args
num_subscribers = self._server.publish(self, channel, message)
self.send(num_subscribers)
elif command == b"SUBSCRIBE":
(channel,) = args
self._server.add_subscriber(self)
self.send(["subscribe", channel, 1])
else:
raise Exception("Unknown command")
def send(self, msg):
"""Send a message back to the client.
"""
raw = self.encode(msg).encode("utf-8")
self.transport.write(raw)
self.transport.flush()
def encode(self, obj):
"""Encode an object to its Redis format.
Supports: strings/bytes, integers and list/tuples.
"""
if isinstance(obj, bytes):
# We assume bytes are just unicode strings.
obj = obj.decode("utf-8")
if isinstance(obj, str):
return "${len}\r\n{str}\r\n".format(len=len(obj), str=obj)
if isinstance(obj, int):
return ":{val}\r\n".format(val=obj)
if isinstance(obj, (list, tuple)):
items = "".join(self.encode(a) for a in obj)
return "*{len}\r\n{items}".format(len=len(obj), items=items)
raise Exception("Unrecognized type for encoding redis: %r: %r", type(obj), obj)
def connectionList(self, reason):
self._server.remove_subscriber(self)