Lay some foundation work to allow workers to only subscribe to some kinds of messages, reducing replication traffic. (#12672)
parent
eb4aaa1b4b
commit
177b884ad7
|
@ -0,0 +1 @@
|
||||||
|
Lay some foundation work to allow workers to only subscribe to some kinds of messages, reducing replication traffic.
|
|
@ -1,5 +1,5 @@
|
||||||
# Copyright 2017 Vector Creations Ltd
|
# Copyright 2017 Vector Creations Ltd
|
||||||
# Copyright 2020 The Matrix.org Foundation C.I.C.
|
# Copyright 2020, 2022 The Matrix.org Foundation C.I.C.
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
# you may not use this file except in compliance with the License.
|
# you may not use this file except in compliance with the License.
|
||||||
|
@ -101,6 +101,9 @@ class ReplicationCommandHandler:
|
||||||
self._instance_id = hs.get_instance_id()
|
self._instance_id = hs.get_instance_id()
|
||||||
self._instance_name = hs.get_instance_name()
|
self._instance_name = hs.get_instance_name()
|
||||||
|
|
||||||
|
# Additional Redis channel suffixes to subscribe to.
|
||||||
|
self._channels_to_subscribe_to: List[str] = []
|
||||||
|
|
||||||
self._is_presence_writer = (
|
self._is_presence_writer = (
|
||||||
hs.get_instance_name() in hs.config.worker.writers.presence
|
hs.get_instance_name() in hs.config.worker.writers.presence
|
||||||
)
|
)
|
||||||
|
@ -243,6 +246,31 @@ class ReplicationCommandHandler:
|
||||||
# If we're NOT using Redis, this must be handled by the master
|
# If we're NOT using Redis, this must be handled by the master
|
||||||
self._should_insert_client_ips = hs.get_instance_name() == "master"
|
self._should_insert_client_ips = hs.get_instance_name() == "master"
|
||||||
|
|
||||||
|
if self._is_master or self._should_insert_client_ips:
|
||||||
|
self.subscribe_to_channel("USER_IP")
|
||||||
|
|
||||||
|
def subscribe_to_channel(self, channel_name: str) -> None:
|
||||||
|
"""
|
||||||
|
Indicates that we wish to subscribe to a Redis channel by name.
|
||||||
|
|
||||||
|
(The name will later be prefixed with the server name; i.e. subscribing
|
||||||
|
to the 'ABC' channel actually subscribes to 'example.com/ABC' Redis-side.)
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
- If replication has already started, then it's too late to subscribe
|
||||||
|
to new channels.
|
||||||
|
"""
|
||||||
|
|
||||||
|
if self._factory is not None:
|
||||||
|
# We don't allow subscribing after the fact to avoid the chance
|
||||||
|
# of missing an important message because we didn't subscribe in time.
|
||||||
|
raise RuntimeError(
|
||||||
|
"Cannot subscribe to more channels after replication started."
|
||||||
|
)
|
||||||
|
|
||||||
|
if channel_name not in self._channels_to_subscribe_to:
|
||||||
|
self._channels_to_subscribe_to.append(channel_name)
|
||||||
|
|
||||||
def _add_command_to_stream_queue(
|
def _add_command_to_stream_queue(
|
||||||
self, conn: IReplicationConnection, cmd: Union[RdataCommand, PositionCommand]
|
self, conn: IReplicationConnection, cmd: Union[RdataCommand, PositionCommand]
|
||||||
) -> None:
|
) -> None:
|
||||||
|
@ -321,7 +349,9 @@ class ReplicationCommandHandler:
|
||||||
|
|
||||||
# Now create the factory/connection for the subscription stream.
|
# Now create the factory/connection for the subscription stream.
|
||||||
self._factory = RedisDirectTcpReplicationClientFactory(
|
self._factory = RedisDirectTcpReplicationClientFactory(
|
||||||
hs, outbound_redis_connection
|
hs,
|
||||||
|
outbound_redis_connection,
|
||||||
|
channel_names=self._channels_to_subscribe_to,
|
||||||
)
|
)
|
||||||
hs.get_reactor().connectTCP(
|
hs.get_reactor().connectTCP(
|
||||||
hs.config.redis.redis_host,
|
hs.config.redis.redis_host,
|
||||||
|
|
|
@ -14,7 +14,7 @@
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from inspect import isawaitable
|
from inspect import isawaitable
|
||||||
from typing import TYPE_CHECKING, Any, Generic, Optional, Type, TypeVar, cast
|
from typing import TYPE_CHECKING, Any, Generic, List, Optional, Type, TypeVar, cast
|
||||||
|
|
||||||
import attr
|
import attr
|
||||||
import txredisapi
|
import txredisapi
|
||||||
|
@ -85,14 +85,15 @@ class RedisSubscriber(txredisapi.SubscriberProtocol):
|
||||||
|
|
||||||
Attributes:
|
Attributes:
|
||||||
synapse_handler: The command handler to handle incoming commands.
|
synapse_handler: The command handler to handle incoming commands.
|
||||||
synapse_stream_name: The *redis* stream name to subscribe to and publish
|
synapse_stream_prefix: The *redis* stream name to subscribe to and publish
|
||||||
from (not anything to do with Synapse replication streams).
|
from (not anything to do with Synapse replication streams).
|
||||||
synapse_outbound_redis_connection: The connection to redis to use to send
|
synapse_outbound_redis_connection: The connection to redis to use to send
|
||||||
commands.
|
commands.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
synapse_handler: "ReplicationCommandHandler"
|
synapse_handler: "ReplicationCommandHandler"
|
||||||
synapse_stream_name: str
|
synapse_stream_prefix: str
|
||||||
|
synapse_channel_names: List[str]
|
||||||
synapse_outbound_redis_connection: txredisapi.ConnectionHandler
|
synapse_outbound_redis_connection: txredisapi.ConnectionHandler
|
||||||
|
|
||||||
def __init__(self, *args: Any, **kwargs: Any):
|
def __init__(self, *args: Any, **kwargs: Any):
|
||||||
|
@ -117,8 +118,13 @@ class RedisSubscriber(txredisapi.SubscriberProtocol):
|
||||||
# it's important to make sure that we only send the REPLICATE command once we
|
# it's important to make sure that we only send the REPLICATE command once we
|
||||||
# have successfully subscribed to the stream - otherwise we might miss the
|
# have successfully subscribed to the stream - otherwise we might miss the
|
||||||
# POSITION response sent back by the other end.
|
# POSITION response sent back by the other end.
|
||||||
logger.info("Sending redis SUBSCRIBE for %s", self.synapse_stream_name)
|
fully_qualified_stream_names = [
|
||||||
await make_deferred_yieldable(self.subscribe(self.synapse_stream_name))
|
f"{self.synapse_stream_prefix}/{stream_suffix}"
|
||||||
|
for stream_suffix in self.synapse_channel_names
|
||||||
|
] + [self.synapse_stream_prefix]
|
||||||
|
logger.info("Sending redis SUBSCRIBE for %r", fully_qualified_stream_names)
|
||||||
|
await make_deferred_yieldable(self.subscribe(fully_qualified_stream_names))
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
"Successfully subscribed to redis stream, sending REPLICATE command"
|
"Successfully subscribed to redis stream, sending REPLICATE command"
|
||||||
)
|
)
|
||||||
|
@ -217,7 +223,7 @@ class RedisSubscriber(txredisapi.SubscriberProtocol):
|
||||||
|
|
||||||
await make_deferred_yieldable(
|
await make_deferred_yieldable(
|
||||||
self.synapse_outbound_redis_connection.publish(
|
self.synapse_outbound_redis_connection.publish(
|
||||||
self.synapse_stream_name, encoded_string
|
self.synapse_stream_prefix, encoded_string
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -300,20 +306,27 @@ def format_address(address: IAddress) -> str:
|
||||||
|
|
||||||
class RedisDirectTcpReplicationClientFactory(SynapseRedisFactory):
|
class RedisDirectTcpReplicationClientFactory(SynapseRedisFactory):
|
||||||
"""This is a reconnecting factory that connects to redis and immediately
|
"""This is a reconnecting factory that connects to redis and immediately
|
||||||
subscribes to a stream.
|
subscribes to some streams.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
hs
|
hs
|
||||||
outbound_redis_connection: A connection to redis that will be used to
|
outbound_redis_connection: A connection to redis that will be used to
|
||||||
send outbound commands (this is separate to the redis connection
|
send outbound commands (this is separate to the redis connection
|
||||||
used to subscribe).
|
used to subscribe).
|
||||||
|
channel_names: A list of channel names to append to the base channel name
|
||||||
|
to additionally subscribe to.
|
||||||
|
e.g. if ['ABC', 'DEF'] is specified then we'll listen to:
|
||||||
|
example.com; example.com/ABC; and example.com/DEF.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
maxDelay = 5
|
maxDelay = 5
|
||||||
protocol = RedisSubscriber
|
protocol = RedisSubscriber
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self, hs: "HomeServer", outbound_redis_connection: txredisapi.ConnectionHandler
|
self,
|
||||||
|
hs: "HomeServer",
|
||||||
|
outbound_redis_connection: txredisapi.ConnectionHandler,
|
||||||
|
channel_names: List[str],
|
||||||
):
|
):
|
||||||
|
|
||||||
super().__init__(
|
super().__init__(
|
||||||
|
@ -326,7 +339,8 @@ class RedisDirectTcpReplicationClientFactory(SynapseRedisFactory):
|
||||||
)
|
)
|
||||||
|
|
||||||
self.synapse_handler = hs.get_replication_command_handler()
|
self.synapse_handler = hs.get_replication_command_handler()
|
||||||
self.synapse_stream_name = hs.hostname
|
self.synapse_stream_prefix = hs.hostname
|
||||||
|
self.synapse_channel_names = channel_names
|
||||||
|
|
||||||
self.synapse_outbound_redis_connection = outbound_redis_connection
|
self.synapse_outbound_redis_connection = outbound_redis_connection
|
||||||
|
|
||||||
|
@ -340,7 +354,8 @@ class RedisDirectTcpReplicationClientFactory(SynapseRedisFactory):
|
||||||
# protocol.
|
# protocol.
|
||||||
p.synapse_handler = self.synapse_handler
|
p.synapse_handler = self.synapse_handler
|
||||||
p.synapse_outbound_redis_connection = self.synapse_outbound_redis_connection
|
p.synapse_outbound_redis_connection = self.synapse_outbound_redis_connection
|
||||||
p.synapse_stream_name = self.synapse_stream_name
|
p.synapse_stream_prefix = self.synapse_stream_prefix
|
||||||
|
p.synapse_channel_names = self.synapse_channel_names
|
||||||
|
|
||||||
return p
|
return p
|
||||||
|
|
||||||
|
|
|
@ -12,7 +12,8 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
import logging
|
import logging
|
||||||
from typing import Any, Dict, List, Optional, Tuple
|
from collections import defaultdict
|
||||||
|
from typing import Any, Dict, List, Optional, Set, Tuple
|
||||||
|
|
||||||
from twisted.internet.address import IPv4Address
|
from twisted.internet.address import IPv4Address
|
||||||
from twisted.internet.protocol import Protocol
|
from twisted.internet.protocol import Protocol
|
||||||
|
@ -32,6 +33,7 @@ from synapse.server import HomeServer
|
||||||
|
|
||||||
from tests import unittest
|
from tests import unittest
|
||||||
from tests.server import FakeTransport
|
from tests.server import FakeTransport
|
||||||
|
from tests.utils import USE_POSTGRES_FOR_TESTS
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import hiredis
|
import hiredis
|
||||||
|
@ -475,22 +477,25 @@ class FakeRedisPubSubServer:
|
||||||
"""A fake Redis server for pub/sub."""
|
"""A fake Redis server for pub/sub."""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self._subscribers = set()
|
self._subscribers_by_channel: Dict[
|
||||||
|
bytes, Set["FakeRedisPubSubProtocol"]
|
||||||
|
] = defaultdict(set)
|
||||||
|
|
||||||
def add_subscriber(self, conn):
|
def add_subscriber(self, conn, channel: bytes):
|
||||||
"""A connection has called SUBSCRIBE"""
|
"""A connection has called SUBSCRIBE"""
|
||||||
self._subscribers.add(conn)
|
self._subscribers_by_channel[channel].add(conn)
|
||||||
|
|
||||||
def remove_subscriber(self, conn):
|
def remove_subscriber(self, conn):
|
||||||
"""A connection has called UNSUBSCRIBE"""
|
"""A connection has lost connection"""
|
||||||
self._subscribers.discard(conn)
|
for subscribers in self._subscribers_by_channel.values():
|
||||||
|
subscribers.discard(conn)
|
||||||
|
|
||||||
def publish(self, conn, channel, msg) -> int:
|
def publish(self, conn, channel: bytes, msg) -> int:
|
||||||
"""A connection want to publish a message to subscribers."""
|
"""A connection want to publish a message to subscribers."""
|
||||||
for sub in self._subscribers:
|
for sub in self._subscribers_by_channel[channel]:
|
||||||
sub.send(["message", channel, msg])
|
sub.send(["message", channel, msg])
|
||||||
|
|
||||||
return len(self._subscribers)
|
return len(self._subscribers_by_channel)
|
||||||
|
|
||||||
def buildProtocol(self, addr):
|
def buildProtocol(self, addr):
|
||||||
return FakeRedisPubSubProtocol(self)
|
return FakeRedisPubSubProtocol(self)
|
||||||
|
@ -531,9 +536,10 @@ class FakeRedisPubSubProtocol(Protocol):
|
||||||
num_subscribers = self._server.publish(self, channel, message)
|
num_subscribers = self._server.publish(self, channel, message)
|
||||||
self.send(num_subscribers)
|
self.send(num_subscribers)
|
||||||
elif command == b"SUBSCRIBE":
|
elif command == b"SUBSCRIBE":
|
||||||
(channel,) = args
|
for idx, channel in enumerate(args):
|
||||||
self._server.add_subscriber(self)
|
num_channels = idx + 1
|
||||||
self.send(["subscribe", channel, 1])
|
self._server.add_subscriber(self, channel)
|
||||||
|
self.send(["subscribe", channel, num_channels])
|
||||||
|
|
||||||
# Since we use SET/GET to cache things we can safely no-op them.
|
# Since we use SET/GET to cache things we can safely no-op them.
|
||||||
elif command == b"SET":
|
elif command == b"SET":
|
||||||
|
@ -576,3 +582,27 @@ class FakeRedisPubSubProtocol(Protocol):
|
||||||
|
|
||||||
def connectionLost(self, reason):
|
def connectionLost(self, reason):
|
||||||
self._server.remove_subscriber(self)
|
self._server.remove_subscriber(self)
|
||||||
|
|
||||||
|
|
||||||
|
class RedisMultiWorkerStreamTestCase(BaseMultiWorkerStreamTestCase):
|
||||||
|
"""
|
||||||
|
A test case that enables Redis, providing a fake Redis server.
|
||||||
|
"""
|
||||||
|
|
||||||
|
if not hiredis:
|
||||||
|
skip = "Requires hiredis"
|
||||||
|
|
||||||
|
if not USE_POSTGRES_FOR_TESTS:
|
||||||
|
# Redis replication only takes place on Postgres
|
||||||
|
skip = "Requires Postgres"
|
||||||
|
|
||||||
|
def default_config(self) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Overrides the default config to enable Redis.
|
||||||
|
Even if the test only uses make_worker_hs, the main process needs Redis
|
||||||
|
enabled otherwise it won't create a Fake Redis server to listen on the
|
||||||
|
Redis port and accept fake TCP connections.
|
||||||
|
"""
|
||||||
|
base = super().default_config()
|
||||||
|
base["redis"] = {"enabled": True}
|
||||||
|
return base
|
||||||
|
|
|
@ -0,0 +1,73 @@
|
||||||
|
# Copyright 2022 The Matrix.org Foundation C.I.C.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from tests.replication._base import RedisMultiWorkerStreamTestCase
|
||||||
|
|
||||||
|
|
||||||
|
class ChannelsTestCase(RedisMultiWorkerStreamTestCase):
|
||||||
|
def test_subscribed_to_enough_redis_channels(self) -> None:
|
||||||
|
# The default main process is subscribed to the USER_IP channel.
|
||||||
|
self.assertCountEqual(
|
||||||
|
self.hs.get_replication_command_handler()._channels_to_subscribe_to,
|
||||||
|
["USER_IP"],
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_background_worker_subscribed_to_user_ip(self) -> None:
|
||||||
|
# The default main process is subscribed to the USER_IP channel.
|
||||||
|
worker1 = self.make_worker_hs(
|
||||||
|
"synapse.app.generic_worker",
|
||||||
|
extra_config={
|
||||||
|
"worker_name": "worker1",
|
||||||
|
"run_background_tasks_on": "worker1",
|
||||||
|
"redis": {"enabled": True},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
self.assertIn(
|
||||||
|
"USER_IP",
|
||||||
|
worker1.get_replication_command_handler()._channels_to_subscribe_to,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Advance so the Redis subscription gets processed
|
||||||
|
self.pump(0.1)
|
||||||
|
|
||||||
|
# The counts are 2 because both the main process and the worker are subscribed.
|
||||||
|
self.assertEqual(len(self._redis_server._subscribers_by_channel[b"test"]), 2)
|
||||||
|
self.assertEqual(
|
||||||
|
len(self._redis_server._subscribers_by_channel[b"test/USER_IP"]), 2
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_non_background_worker_not_subscribed_to_user_ip(self) -> None:
|
||||||
|
# The default main process is subscribed to the USER_IP channel.
|
||||||
|
worker2 = self.make_worker_hs(
|
||||||
|
"synapse.app.generic_worker",
|
||||||
|
extra_config={
|
||||||
|
"worker_name": "worker2",
|
||||||
|
"run_background_tasks_on": "worker1",
|
||||||
|
"redis": {"enabled": True},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
self.assertNotIn(
|
||||||
|
"USER_IP",
|
||||||
|
worker2.get_replication_command_handler()._channels_to_subscribe_to,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Advance so the Redis subscription gets processed
|
||||||
|
self.pump(0.1)
|
||||||
|
|
||||||
|
# The count is 2 because both the main process and the worker are subscribed.
|
||||||
|
self.assertEqual(len(self._redis_server._subscribers_by_channel[b"test"]), 2)
|
||||||
|
# For USER_IP, the count is 1 because only the main process is subscribed.
|
||||||
|
self.assertEqual(
|
||||||
|
len(self._redis_server._subscribers_by_channel[b"test/USER_IP"]), 1
|
||||||
|
)
|
Loading…
Reference in New Issue