Add ability to run multiple pusher instances ()

This reuses the same scheme as federation sender sharding
pull/7803/head
Erik Johnston 2020-07-16 14:06:28 +01:00 committed by GitHub
parent a827838706
commit 649a7ead5c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 293 additions and 82 deletions

1
changelog.d/7855.feature Normal file
View File

@ -0,0 +1 @@
Add experimental support for running multiple pusher workers.

View File

@ -19,9 +19,11 @@ import argparse
import errno
import os
from collections import OrderedDict
from hashlib import sha256
from textwrap import dedent
from typing import Any, MutableMapping, Optional
from typing import Any, List, MutableMapping, Optional
import attr
import yaml
@ -717,4 +719,36 @@ def find_config_files(search_paths):
return config_files
__all__ = ["Config", "RootConfig"]
@attr.s
class ShardedWorkerHandlingConfig:
"""Algorithm for choosing which instance is responsible for handling some
sharded work.
For example, the federation senders use this to determine which instances
handles sending stuff to a given destination (which is used as the `key`
below).
"""
instances = attr.ib(type=List[str])
def should_handle(self, instance_name: str, key: str) -> bool:
"""Whether this instance is responsible for handling the given key.
"""
# If multiple instances are not defined we always return true.
if not self.instances or len(self.instances) == 1:
return True
# We shard by taking the hash, modulo it by the number of instances and
# then checking whether this instance matches the instance at that
# index.
#
# (Technically this introduces some bias and is not entirely uniform,
# but since the hash is so large the bias is ridiculously small).
dest_hash = sha256(key.encode("utf8")).digest()
dest_int = int.from_bytes(dest_hash, byteorder="little")
remainder = dest_int % (len(self.instances))
return self.instances[remainder] == instance_name
__all__ = ["Config", "RootConfig", "ShardedWorkerHandlingConfig"]

View File

@ -137,3 +137,8 @@ class Config:
def read_config_files(config_files: List[str]): ...
def find_config_files(search_paths: List[str]): ...
class ShardedWorkerHandlingConfig:
instances: List[str]
def __init__(self, instances: List[str]) -> None: ...
def should_handle(self, instance_name: str, key: str) -> bool: ...

View File

@ -13,42 +13,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from hashlib import sha256
from typing import List, Optional
from typing import Optional
import attr
from netaddr import IPSet
from ._base import Config, ConfigError
@attr.s
class ShardedFederationSendingConfig:
"""Algorithm for choosing which federation sender instance is responsible
for which destionation host.
"""
instances = attr.ib(type=List[str])
def should_send_to(self, instance_name: str, destination: str) -> bool:
"""Whether this instance is responsible for sending transcations for
the given host.
"""
# If multiple federation senders are not defined we always return true.
if not self.instances or len(self.instances) == 1:
return True
# We shard by taking the hash, modulo it by the number of federation
# senders and then checking whether this instance matches the instance
# at that index.
#
# (Technically this introduces some bias and is not entirely uniform, but
# since the hash is so large the bias is ridiculously small).
dest_hash = sha256(destination.encode("utf8")).digest()
dest_int = int.from_bytes(dest_hash, byteorder="little")
remainder = dest_int % (len(self.instances))
return self.instances[remainder] == instance_name
from ._base import Config, ConfigError, ShardedWorkerHandlingConfig
class FederationConfig(Config):
@ -61,7 +30,7 @@ class FederationConfig(Config):
self.send_federation = config.get("send_federation", True)
federation_sender_instances = config.get("federation_sender_instances") or []
self.federation_shard_config = ShardedFederationSendingConfig(
self.federation_shard_config = ShardedWorkerHandlingConfig(
federation_sender_instances
)

View File

@ -14,7 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from ._base import Config
from ._base import Config, ShardedWorkerHandlingConfig
class PushConfig(Config):
@ -24,6 +24,9 @@ class PushConfig(Config):
push_config = config.get("push", {})
self.push_include_content = push_config.get("include_content", True)
pusher_instances = config.get("pusher_instances") or []
self.pusher_shard_config = ShardedWorkerHandlingConfig(pusher_instances)
# There was a a 'redact_content' setting but mistakenly read from the
# 'email'section'. Check for the flag in the 'push' section, and log,
# but do not honour it to avoid nasty surprises when people upgrade.

View File

@ -197,7 +197,7 @@ class FederationSender(object):
destinations = {
d
for d in destinations
if self._federation_shard_config.should_send_to(
if self._federation_shard_config.should_handle(
self._instance_name, d
)
}
@ -335,7 +335,7 @@ class FederationSender(object):
d
for d in domains
if d != self.server_name
and self._federation_shard_config.should_send_to(self._instance_name, d)
and self._federation_shard_config.should_handle(self._instance_name, d)
]
if not domains:
return
@ -441,7 +441,7 @@ class FederationSender(object):
for destination in destinations:
if destination == self.server_name:
continue
if not self._federation_shard_config.should_send_to(
if not self._federation_shard_config.should_handle(
self._instance_name, destination
):
continue
@ -460,7 +460,7 @@ class FederationSender(object):
if destination == self.server_name:
continue
if not self._federation_shard_config.should_send_to(
if not self._federation_shard_config.should_handle(
self._instance_name, destination
):
continue
@ -486,7 +486,7 @@ class FederationSender(object):
logger.info("Not sending EDU to ourselves")
return
if not self._federation_shard_config.should_send_to(
if not self._federation_shard_config.should_handle(
self._instance_name, destination
):
return
@ -507,7 +507,7 @@ class FederationSender(object):
edu: edu to send
key: clobbering key for this edu
"""
if not self._federation_shard_config.should_send_to(
if not self._federation_shard_config.should_handle(
self._instance_name, edu.destination
):
return
@ -523,7 +523,7 @@ class FederationSender(object):
logger.warning("Not sending device update to ourselves")
return
if not self._federation_shard_config.should_send_to(
if not self._federation_shard_config.should_handle(
self._instance_name, destination
):
return
@ -541,7 +541,7 @@ class FederationSender(object):
logger.warning("Not waking up ourselves")
return
if not self._federation_shard_config.should_send_to(
if not self._federation_shard_config.should_handle(
self._instance_name, destination
):
return

View File

@ -78,7 +78,7 @@ class PerDestinationQueue(object):
self._federation_shard_config = hs.config.federation.federation_shard_config
self._should_send_on_this_instance = True
if not self._federation_shard_config.should_send_to(
if not self._federation_shard_config.should_handle(
self._instance_name, destination
):
# We don't raise an exception here to avoid taking out any other

View File

@ -15,13 +15,12 @@
# limitations under the License.
import logging
from collections import defaultdict
from threading import Lock
from typing import Dict, Tuple, Union
from typing import TYPE_CHECKING, Dict, Union
from prometheus_client import Gauge
from twisted.internet import defer
from synapse.metrics import LaterGauge
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.push import PusherConfigException
from synapse.push.emailpusher import EmailPusher
@ -29,9 +28,18 @@ from synapse.push.httppusher import HttpPusher
from synapse.push.pusher import PusherFactory
from synapse.util.async_helpers import concurrently_execute
if TYPE_CHECKING:
from synapse.server import HomeServer
logger = logging.getLogger(__name__)
synapse_pushers = Gauge(
"synapse_pushers", "Number of active synapse pushers", ["kind", "app_id"]
)
class PusherPool:
"""
The pusher pool. This is responsible for dispatching notifications of new events to
@ -47,36 +55,20 @@ class PusherPool:
Pusher.on_new_receipts are not expected to return deferreds.
"""
def __init__(self, _hs):
self.hs = _hs
self.pusher_factory = PusherFactory(_hs)
self._should_start_pushers = _hs.config.start_pushers
def __init__(self, hs: "HomeServer"):
self.hs = hs
self.pusher_factory = PusherFactory(hs)
self._should_start_pushers = hs.config.start_pushers
self.store = self.hs.get_datastore()
self.clock = self.hs.get_clock()
# We shard the handling of push notifications by user ID.
self._pusher_shard_config = hs.config.push.pusher_shard_config
self._instance_name = hs.get_instance_name()
# map from user id to app_id:pushkey to pusher
self.pushers = {} # type: Dict[str, Dict[str, Union[HttpPusher, EmailPusher]]]
# a lock for the pushers dict, since `count_pushers` is called from an different
# and we otherwise get concurrent modification errors
self._pushers_lock = Lock()
def count_pushers():
results = defaultdict(int) # type: Dict[Tuple[str, str], int]
with self._pushers_lock:
for pushers in self.pushers.values():
for pusher in pushers.values():
k = (type(pusher).__name__, pusher.app_id)
results[k] += 1
return results
LaterGauge(
name="synapse_pushers",
desc="the number of active pushers",
labels=["kind", "app_id"],
caller=count_pushers,
)
def start(self):
"""Starts the pushers off in a background process.
"""
@ -104,6 +96,7 @@ class PusherPool:
Returns:
Deferred[EmailPusher|HttpPusher]
"""
time_now_msec = self.clock.time_msec()
# we try to create the pusher just to validate the config: it
@ -176,6 +169,9 @@ class PusherPool:
access_tokens (Iterable[int]): access token *ids* to remove pushers
for
"""
if not self._pusher_shard_config.should_handle(self._instance_name, user_id):
return
tokens = set(access_tokens)
for p in (yield self.store.get_pushers_by_user_id(user_id)):
if p["access_token"] in tokens:
@ -237,6 +233,9 @@ class PusherPool:
if not self._should_start_pushers:
return
if not self._pusher_shard_config.should_handle(self._instance_name, user_id):
return
resultlist = yield self.store.get_pushers_by_app_id_and_pushkey(app_id, pushkey)
pusher_dict = None
@ -275,6 +274,11 @@ class PusherPool:
Returns:
Deferred[EmailPusher|HttpPusher]
"""
if not self._pusher_shard_config.should_handle(
self._instance_name, pusherdict["user_name"]
):
return
try:
p = self.pusher_factory.create_pusher(pusherdict)
except PusherConfigException as e:
@ -298,12 +302,13 @@ class PusherPool:
appid_pushkey = "%s:%s" % (pusherdict["app_id"], pusherdict["pushkey"])
with self._pushers_lock:
byuser = self.pushers.setdefault(pusherdict["user_name"], {})
if appid_pushkey in byuser:
byuser[appid_pushkey].on_stop()
byuser[appid_pushkey] = p
synapse_pushers.labels(type(p).__name__, p.app_id).inc()
# Check if there *may* be push to process. We do this as this check is a
# lot cheaper to do than actually fetching the exact rows we need to
# push.
@ -330,9 +335,10 @@ class PusherPool:
if appid_pushkey in byuser:
logger.info("Stopping pusher %s / %s", user_id, appid_pushkey)
byuser[appid_pushkey].on_stop()
with self._pushers_lock:
del byuser[appid_pushkey]
pusher = byuser.pop(appid_pushkey)
pusher.on_stop()
synapse_pushers.labels(type(pusher).__name__, pusher.app_id).dec()
yield self.store.delete_pusher_by_app_id_pushkey_user_id(
app_id, pushkey, user_id

View File

@ -0,0 +1,193 @@
# -*- coding: utf-8 -*-
# Copyright 2020 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from mock import Mock
from twisted.internet import defer
from synapse.rest import admin
from synapse.rest.client.v1 import login, room
from tests.replication._base import BaseMultiWorkerStreamTestCase
logger = logging.getLogger(__name__)
class PusherShardTestCase(BaseMultiWorkerStreamTestCase):
"""Checks pusher sharding works
"""
servlets = [
admin.register_servlets_for_client_rest_resource,
room.register_servlets,
login.register_servlets,
]
def prepare(self, reactor, clock, hs):
# Register a user who sends a message that we'll get notified about
self.other_user_id = self.register_user("otheruser", "pass")
self.other_access_token = self.login("otheruser", "pass")
def default_config(self):
conf = super().default_config()
conf["start_pushers"] = False
return conf
def _create_pusher_and_send_msg(self, localpart):
# Create a user that will get push notifications
user_id = self.register_user(localpart, "pass")
access_token = self.login(localpart, "pass")
# Register a pusher
user_dict = self.get_success(
self.hs.get_datastore().get_user_by_access_token(access_token)
)
token_id = user_dict["token_id"]
self.get_success(
self.hs.get_pusherpool().add_pusher(
user_id=user_id,
access_token=token_id,
kind="http",
app_id="m.http",
app_display_name="HTTP Push Notifications",
device_display_name="pushy push",
pushkey="a@example.com",
lang=None,
data={"url": "https://push.example.com/push"},
)
)
self.pump()
# Create a room
room = self.helper.create_room_as(user_id, tok=access_token)
# The other user joins
self.helper.join(
room=room, user=self.other_user_id, tok=self.other_access_token
)
# The other user sends some messages
response = self.helper.send(room, body="Hi!", tok=self.other_access_token)
event_id = response["event_id"]
return event_id
def test_send_push_single_worker(self):
"""Test that registration works when using a pusher worker.
"""
http_client_mock = Mock(spec_set=["post_json_get_json"])
http_client_mock.post_json_get_json.side_effect = lambda *_, **__: defer.succeed(
{}
)
self.make_worker_hs(
"synapse.app.pusher",
{"start_pushers": True},
proxied_http_client=http_client_mock,
)
event_id = self._create_pusher_and_send_msg("user")
# Advance time a bit, so the pusher will register something has happened
self.pump()
http_client_mock.post_json_get_json.assert_called_once()
self.assertEqual(
http_client_mock.post_json_get_json.call_args[0][0],
"https://push.example.com/push",
)
self.assertEqual(
event_id,
http_client_mock.post_json_get_json.call_args[0][1]["notification"][
"event_id"
],
)
def test_send_push_multiple_workers(self):
"""Test that registration works when using sharded pusher workers.
"""
http_client_mock1 = Mock(spec_set=["post_json_get_json"])
http_client_mock1.post_json_get_json.side_effect = lambda *_, **__: defer.succeed(
{}
)
self.make_worker_hs(
"synapse.app.pusher",
{
"start_pushers": True,
"worker_name": "pusher1",
"pusher_instances": ["pusher1", "pusher2"],
},
proxied_http_client=http_client_mock1,
)
http_client_mock2 = Mock(spec_set=["post_json_get_json"])
http_client_mock2.post_json_get_json.side_effect = lambda *_, **__: defer.succeed(
{}
)
self.make_worker_hs(
"synapse.app.pusher",
{
"start_pushers": True,
"worker_name": "pusher2",
"pusher_instances": ["pusher1", "pusher2"],
},
proxied_http_client=http_client_mock2,
)
# We choose a user name that we know should go to pusher1.
event_id = self._create_pusher_and_send_msg("user2")
# Advance time a bit, so the pusher will register something has happened
self.pump()
http_client_mock1.post_json_get_json.assert_called_once()
http_client_mock2.post_json_get_json.assert_not_called()
self.assertEqual(
http_client_mock1.post_json_get_json.call_args[0][0],
"https://push.example.com/push",
)
self.assertEqual(
event_id,
http_client_mock1.post_json_get_json.call_args[0][1]["notification"][
"event_id"
],
)
http_client_mock1.post_json_get_json.reset_mock()
http_client_mock2.post_json_get_json.reset_mock()
# Now we choose a user name that we know should go to pusher2.
event_id = self._create_pusher_and_send_msg("user4")
# Advance time a bit, so the pusher will register something has happened
self.pump()
http_client_mock1.post_json_get_json.assert_not_called()
http_client_mock2.post_json_get_json.assert_called_once()
self.assertEqual(
http_client_mock2.post_json_get_json.call_args[0][0],
"https://push.example.com/push",
)
self.assertEqual(
event_id,
http_client_mock2.post_json_get_json.call_args[0][1]["notification"][
"event_id"
],
)