188 lines
		
	
	
		
			6.6 KiB
		
	
	
	
		
			Python
		
	
	
			
		
		
	
	
			188 lines
		
	
	
		
			6.6 KiB
		
	
	
	
		
			Python
		
	
	
| # Copyright 2023 The Matrix.org Foundation C.I.C.
 | |
| #
 | |
| # Licensed under the Apache License, Version 2.0 (the "License");
 | |
| # you may not use this file except in compliance with the License.
 | |
| # You may obtain a copy of the License at
 | |
| #
 | |
| #     http://www.apache.org/licenses/LICENSE-2.0
 | |
| #
 | |
| # Unless required by applicable law or agreed to in writing, software
 | |
| # distributed under the License is distributed on an "AS IS" BASIS,
 | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | |
| # See the License for the specific language governing permissions and
 | |
| # limitations under the License.
 | |
| 
 | |
| import logging
 | |
| from typing import Dict, Optional
 | |
| 
 | |
| from zope.interface import implementer
 | |
| 
 | |
| from twisted.internet import defer
 | |
| from twisted.internet.endpoints import (
 | |
|     HostnameEndpoint,
 | |
|     UNIXClientEndpoint,
 | |
|     wrapClientTLS,
 | |
| )
 | |
| from twisted.internet.interfaces import IStreamClientEndpoint
 | |
| from twisted.python.failure import Failure
 | |
| from twisted.web.client import URI, HTTPConnectionPool, _AgentBase
 | |
| from twisted.web.error import SchemeNotSupported
 | |
| from twisted.web.http_headers import Headers
 | |
| from twisted.web.iweb import (
 | |
|     IAgent,
 | |
|     IAgentEndpointFactory,
 | |
|     IBodyProducer,
 | |
|     IPolicyForHTTPS,
 | |
|     IResponse,
 | |
| )
 | |
| 
 | |
| from synapse.config.workers import (
 | |
|     InstanceLocationConfig,
 | |
|     InstanceTcpLocationConfig,
 | |
|     InstanceUnixLocationConfig,
 | |
| )
 | |
| from synapse.types import ISynapseReactor
 | |
| 
 | |
| logger = logging.getLogger(__name__)
 | |
| 
 | |
| 
 | |
@implementer(IAgentEndpointFactory)
class ReplicationEndpointFactory:
    """Build client endpoints (TCP, TLS-over-TCP or UNIX socket) for workers.

    The worker to connect to is named by the netloc of the URI; where that
    worker can actually be reached is looked up in the instance map.
    """

    def __init__(
        self,
        reactor: ISynapseReactor,
        instance_map: Dict[str, InstanceLocationConfig],
        context_factory: IPolicyForHTTPS,
    ) -> None:
        """
        Args:
            reactor: The reactor handed to every endpoint this factory creates.
            instance_map: Maps a worker name to the location config describing
                how to reach that worker.
            context_factory: Supplies the TLS connection creator used when a
                TCP location reports the "https" scheme.
        """
        self.reactor = reactor
        self.instance_map = instance_map
        self.context_factory = context_factory

    def _tcp_endpoint(
        self, location: InstanceTcpLocationConfig, scheme: str
    ) -> IStreamClientEndpoint:
        """Create the endpoint for a TCP location, adding TLS for "https".

        Args:
            location: The TCP location config (provides `host` and `port`).
            scheme: The scheme reported by the location config.

        Returns: A plain hostname endpoint, or that endpoint wrapped in TLS
            when the scheme is "https".
        """
        plain = HostnameEndpoint(self.reactor, location.host, location.port)
        if scheme != "https":
            return plain

        # The port is passed for completeness only; the creator lookup does
        # not actually make use of it.
        tls_creator = self.context_factory.creatorForNetloc(
            location.host.encode("utf-8"),
            location.port,
        )
        return wrapClientTLS(tls_creator, plain)

    def endpointForURI(self, uri: URI) -> IStreamClientEndpoint:
        """
        Pick the kind of endpoint to use for the worker named in the URI.

        Args:
            uri: The pre-parsed URI object; its netloc carries the worker name.

        Returns: The client endpoint matching the worker's location config.

        Raises:
            SchemeNotSupported: The location config is neither a TCP nor a
                UNIX socket location.

        A worker name missing from the instance map is not handled here; the
        mapping's own lookup error (KeyError for a dict) propagates.
        """
        # The URI uses a special scheme whose netloc is the worker name rather
        # than a real host; the real connection details live in the instance
        # map.
        name = uri.netloc.decode("utf-8")
        location = self.instance_map[name]
        scheme = location.scheme()

        if isinstance(location, InstanceTcpLocationConfig):
            return self._tcp_endpoint(location, scheme)
        if isinstance(location, InstanceUnixLocationConfig):
            return UNIXClientEndpoint(self.reactor, location.path)
        raise SchemeNotSupported(f"Unsupported scheme: {scheme}")
 | |
| 
 | |
| 
 | |
@implementer(IAgent)
class ReplicationAgent(_AgentBase):
    """
    HTTP/HTTPS client agent used to talk to replication endpoints.

    A large part of this is lifted from Twisted's twisted.web.client.Agent.
    """

    def __init__(
        self,
        reactor: ISynapseReactor,
        instance_map: Dict[str, InstanceLocationConfig],
        contextFactory: IPolicyForHTTPS,
        connectTimeout: Optional[float] = None,
        bindAddress: Optional[bytes] = None,
        pool: Optional[HTTPConnectionPool] = None,
    ) -> None:
        """
        Create a ReplicationAgent.

        Args:
            reactor: A reactor for this Agent to place outgoing connections.
            instance_map: Maps a worker name to its location config; given to
                the endpoint factory and also used to build the pool key.
            contextFactory: A factory for TLS contexts, passed on to the
                endpoint factory.
            connectTimeout: The amount of time to wait for the peer to accept
                a connection. NOTE(review): accepted but never read in this
                class - presumably kept for parity with Agent; confirm.
            bindAddress: The local address for client sockets to bind to.
                NOTE(review): accepted but never read in this class - confirm
                whether it should be forwarded to the endpoints.
            pool: An HTTPConnectionPool instance, or None, in which case a
                non-persistent HTTPConnectionPool instance will be created.
        """
        _AgentBase.__init__(self, reactor, pool)
        self._endpointFactory = ReplicationEndpointFactory(
            reactor, instance_map, contextFactory
        )

    def request(
        self,
        method: bytes,
        uri: bytes,
        headers: Optional[Headers] = None,
        bodyProducer: Optional[IBodyProducer] = None,
    ) -> "defer.Deferred[IResponse]":
        """
        Send a request to the worker named by the given uri.

        The connection is either taken from the connection pool or freshly
        created.

        HTTP, HTTPS and UNIX schemes are currently supported.

        This follows twisted.web.client.Agent.request, with two differences:

        * The pool key is different: the scheme combined with either the
          host & port or the socket path.
        * _ensureValidURI(...) is not called. The strictness of IDNA2008 is
          not needed when a worker's name stands in as the 'hostname' for the
          Synapse HTTP Replication machinery, and skipping it allows a range
          of ascii characters such as '+' and '_' in worker names.

        Args:
            method: The HTTP method.
            uri: The request URI; its netloc is the target worker's name.
            headers: Optional request headers.
            bodyProducer: Optional producer for the request body.

        Returns: A Deferred for the response. If the endpoint factory raises
            SchemeNotSupported, an already-failed Deferred wrapping that
            error is returned instead.

        See: twisted.web.iweb.IAgent.request
        """
        target = URI.fromBytes(uri)
        try:
            endpoint = self._endpointFactory.endpointForURI(target)
        except SchemeNotSupported:
            # Failure() with no arguments captures the exception currently
            # being handled, so it has to be built inside this except block.
            return defer.fail(Failure())

        # Looked up only after the endpoint was resolved successfully, so an
        # unsupported location has already been rejected above.
        location = self._endpointFactory.instance_map[target.netloc.decode("utf-8")]

        # The pool key ends up as:
        #  (http(s), <host:port>) or (unix, <socket_path>)
        pool_key = (location.scheme(), location.netloc())

        # _requestWithEndpoint is inherited from _AgentBase.
        return self._requestWithEndpoint(
            pool_key,
            endpoint,
            method,
            target,
            headers,
            bodyProducer,
            target.originForm,
        )
 |