Add authentication to replication endpoints. (#8853)

Authentication is done by checking a shared secret provided
in the Synapse configuration file.
pull/8882/head
Patrick Cloke 2020-12-04 10:56:28 -05:00 committed by GitHub
parent df4b1e9c74
commit 96358cb424
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 184 additions and 15 deletions

1
changelog.d/8853.feature Normal file
View File

@ -0,0 +1 @@
Add optional HTTP authentication to replication endpoints.

View File

@ -2589,6 +2589,13 @@ opentracing:
# #
#run_background_tasks_on: worker1 #run_background_tasks_on: worker1
# A shared secret used by the replication APIs to authenticate HTTP requests
# from workers.
#
# By default this is unused and traffic is not authenticated.
#
#worker_replication_secret: ""
# Configuration for Redis when using workers. This *must* be enabled when # Configuration for Redis when using workers. This *must* be enabled when
# using workers (unless using old style direct TCP configuration). # using workers (unless using old style direct TCP configuration).

View File

@ -89,7 +89,8 @@ shared configuration file.
Normally, only a couple of changes are needed to make an existing configuration Normally, only a couple of changes are needed to make an existing configuration
file suitable for use with workers. First, you need to enable an "HTTP replication file suitable for use with workers. First, you need to enable an "HTTP replication
listener" for the main process; and secondly, you need to enable redis-based listener" for the main process; and secondly, you need to enable redis-based
replication. For example: replication. Optionally, a shared secret can be used to authenticate HTTP
traffic between workers. For example:
```yaml ```yaml
@ -103,6 +104,9 @@ listeners:
resources: resources:
- names: [replication] - names: [replication]
# Add a random shared secret to authenticate traffic.
worker_replication_secret: ""
redis: redis:
enabled: true enabled: true
``` ```

View File

@ -85,6 +85,9 @@ class WorkerConfig(Config):
# The port on the main synapse for HTTP replication endpoint # The port on the main synapse for HTTP replication endpoint
self.worker_replication_http_port = config.get("worker_replication_http_port") self.worker_replication_http_port = config.get("worker_replication_http_port")
# The shared secret used for authentication when connecting to the main synapse.
self.worker_replication_secret = config.get("worker_replication_secret", None)
self.worker_name = config.get("worker_name", self.worker_app) self.worker_name = config.get("worker_name", self.worker_app)
self.worker_main_http_uri = config.get("worker_main_http_uri", None) self.worker_main_http_uri = config.get("worker_main_http_uri", None)
@ -185,6 +188,13 @@ class WorkerConfig(Config):
# data). If not provided this defaults to the main process. # data). If not provided this defaults to the main process.
# #
#run_background_tasks_on: worker1 #run_background_tasks_on: worker1
# A shared secret used by the replication APIs to authenticate HTTP requests
# from workers.
#
# By default this is unused and traffic is not authenticated.
#
#worker_replication_secret: ""
""" """
def read_arguments(self, args): def read_arguments(self, args):

View File

@ -106,6 +106,25 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
assert self.METHOD in ("PUT", "POST", "GET") assert self.METHOD in ("PUT", "POST", "GET")
self._replication_secret = None
if hs.config.worker.worker_replication_secret:
self._replication_secret = hs.config.worker.worker_replication_secret
def _check_auth(self, request) -> None:
# Get the authorization header.
auth_headers = request.requestHeaders.getRawHeaders(b"Authorization")
if len(auth_headers) > 1:
raise RuntimeError("Too many Authorization headers.")
parts = auth_headers[0].split(b" ")
if parts[0] == b"Bearer" and len(parts) == 2:
received_secret = parts[1].decode("ascii")
if self._replication_secret == received_secret:
# Success!
return
raise RuntimeError("Invalid Authorization header.")
@abc.abstractmethod @abc.abstractmethod
async def _serialize_payload(**kwargs): async def _serialize_payload(**kwargs):
"""Static method that is called when creating a request. """Static method that is called when creating a request.
@ -150,6 +169,12 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
outgoing_gauge = _pending_outgoing_requests.labels(cls.NAME) outgoing_gauge = _pending_outgoing_requests.labels(cls.NAME)
replication_secret = None
if hs.config.worker.worker_replication_secret:
replication_secret = hs.config.worker.worker_replication_secret.encode(
"ascii"
)
@trace(opname="outgoing_replication_request") @trace(opname="outgoing_replication_request")
@outgoing_gauge.track_inprogress() @outgoing_gauge.track_inprogress()
async def send_request(instance_name="master", **kwargs): async def send_request(instance_name="master", **kwargs):
@ -202,6 +227,9 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
# the master, and so whether we should clean up or not. # the master, and so whether we should clean up or not.
while True: while True:
headers = {} # type: Dict[bytes, List[bytes]] headers = {} # type: Dict[bytes, List[bytes]]
# Add an authorization header, if configured.
if replication_secret:
headers[b"Authorization"] = [b"Bearer " + replication_secret]
inject_active_span_byte_dict(headers, None, check_destination=False) inject_active_span_byte_dict(headers, None, check_destination=False)
try: try:
result = await request_func(uri, data, headers=headers) result = await request_func(uri, data, headers=headers)
@ -236,21 +264,19 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
""" """
url_args = list(self.PATH_ARGS) url_args = list(self.PATH_ARGS)
handler = self._handle_request
method = self.METHOD method = self.METHOD
if self.CACHE: if self.CACHE:
handler = self._cached_handler # type: ignore
url_args.append("txn_id") url_args.append("txn_id")
args = "/".join("(?P<%s>[^/]+)" % (arg,) for arg in url_args) args = "/".join("(?P<%s>[^/]+)" % (arg,) for arg in url_args)
pattern = re.compile("^/_synapse/replication/%s/%s$" % (self.NAME, args)) pattern = re.compile("^/_synapse/replication/%s/%s$" % (self.NAME, args))
http_server.register_paths( http_server.register_paths(
method, [pattern], handler, self.__class__.__name__, method, [pattern], self._check_auth_and_handle, self.__class__.__name__,
) )
def _cached_handler(self, request, txn_id, **kwargs): def _check_auth_and_handle(self, request, **kwargs):
"""Called on new incoming requests when caching is enabled. Checks """Called on new incoming requests when caching is enabled. Checks
if there is a cached response for the request and returns that, if there is a cached response for the request and returns that,
otherwise calls `_handle_request` and caches its response. otherwise calls `_handle_request` and caches its response.
@ -258,6 +284,15 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
# We just use the txn_id here, but we probably also want to use the # We just use the txn_id here, but we probably also want to use the
# other PATH_ARGS as well. # other PATH_ARGS as well.
assert self.CACHE # Check the authorization headers before handling the request.
if self._replication_secret:
self._check_auth(request)
return self.response_cache.wrap(txn_id, self._handle_request, request, **kwargs) if self.CACHE:
txn_id = kwargs.pop("txn_id")
return self.response_cache.wrap(
txn_id, self._handle_request, request, **kwargs
)
return self._handle_request(request, **kwargs)

View File

@ -0,0 +1,119 @@
# -*- coding: utf-8 -*-
# Copyright 2020 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import Tuple
from synapse.http.site import SynapseRequest
from synapse.rest.client.v2_alpha import register
from tests.replication._base import BaseMultiWorkerStreamTestCase
from tests.server import FakeChannel, make_request
from tests.unittest import override_config
logger = logging.getLogger(__name__)
class WorkerAuthenticationTestCase(BaseMultiWorkerStreamTestCase):
"""Test the authentication of HTTP calls between workers."""
servlets = [register.register_servlets]
def make_homeserver(self, reactor, clock):
config = self.default_config()
# This isn't a real configuration option but is used to provide the main
# homeserver and worker homeserver different options.
main_replication_secret = config.pop("main_replication_secret", None)
if main_replication_secret:
config["worker_replication_secret"] = main_replication_secret
return self.setup_test_homeserver(config=config)
def _get_worker_hs_config(self) -> dict:
config = self.default_config()
config["worker_app"] = "synapse.app.client_reader"
config["worker_replication_host"] = "testserv"
config["worker_replication_http_port"] = "8765"
return config
def _test_register(self) -> Tuple[SynapseRequest, FakeChannel]:
"""Run the actual test:
1. Create a worker homeserver.
2. Start registration by providing a user/password.
3. Complete registration by providing dummy auth (this hits the main synapse).
4. Return the final request.
"""
worker_hs = self.make_worker_hs("synapse.app.client_reader")
site = self._hs_to_site[worker_hs]
request_1, channel_1 = make_request(
self.reactor,
site,
"POST",
"register",
{"username": "user", "type": "m.login.password", "password": "bar"},
) # type: SynapseRequest, FakeChannel
self.assertEqual(request_1.code, 401)
# Grab the session
session = channel_1.json_body["session"]
# also complete the dummy auth
return make_request(
self.reactor,
site,
"POST",
"register",
{"auth": {"session": session, "type": "m.login.dummy"}},
)
def test_no_auth(self):
"""With no authentication the request should finish.
"""
request, channel = self._test_register()
self.assertEqual(request.code, 200)
# We're given a registered user.
self.assertEqual(channel.json_body["user_id"], "@user:test")
@override_config({"main_replication_secret": "my-secret"})
def test_missing_auth(self):
"""If the main process expects a secret that is not provided, an error results.
"""
request, channel = self._test_register()
self.assertEqual(request.code, 500)
@override_config(
{
"main_replication_secret": "my-secret",
"worker_replication_secret": "wrong-secret",
}
)
def test_unauthorized(self):
"""If the main process receives the wrong secret, an error results.
"""
request, channel = self._test_register()
self.assertEqual(request.code, 500)
@override_config({"worker_replication_secret": "my-secret"})
def test_authorized(self):
"""The request should finish when the worker provides the authentication header.
"""
request, channel = self._test_register()
self.assertEqual(request.code, 200)
# We're given a registered user.
self.assertEqual(channel.json_body["user_id"], "@user:test")

View File

@ -14,27 +14,20 @@
# limitations under the License. # limitations under the License.
import logging import logging
from synapse.api.constants import LoginType
from synapse.http.site import SynapseRequest from synapse.http.site import SynapseRequest
from synapse.rest.client.v2_alpha import register from synapse.rest.client.v2_alpha import register
from tests.replication._base import BaseMultiWorkerStreamTestCase from tests.replication._base import BaseMultiWorkerStreamTestCase
from tests.rest.client.v2_alpha.test_auth import DummyRecaptchaChecker
from tests.server import FakeChannel, make_request from tests.server import FakeChannel, make_request
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class ClientReaderTestCase(BaseMultiWorkerStreamTestCase): class ClientReaderTestCase(BaseMultiWorkerStreamTestCase):
"""Base class for tests of the replication streams""" """Test using one or more client readers for registration."""
servlets = [register.register_servlets] servlets = [register.register_servlets]
def prepare(self, reactor, clock, hs):
self.recaptcha_checker = DummyRecaptchaChecker(hs)
auth_handler = hs.get_auth_handler()
auth_handler.checkers[LoginType.RECAPTCHA] = self.recaptcha_checker
def _get_worker_hs_config(self) -> dict: def _get_worker_hs_config(self) -> dict:
config = self.default_config() config = self.default_config()
config["worker_app"] = "synapse.app.client_reader" config["worker_app"] = "synapse.app.client_reader"