Fix incompatibility with Twisted < 21. (#10713)

Turns out that the functionality added in #10546 to skip TLS was incompatible
with older Twisted versions, so we need to be a bit more inventive.

Also, add a test to (hopefully) not break this in future. Sadly, testing TLS is
really hard.
pull/10723/head
Richard van der Hoff 2021-08-27 16:33:41 +01:00 committed by GitHub
parent f03cafb50c
commit 8f98260552
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 173 additions and 21 deletions

1
changelog.d/10713.bugfix Normal file
View File

@ -0,0 +1 @@
Fix a regression introduced in Synapse 1.41 which broke email transmission on Systems using older versions of the Twisted library.

View File

@ -87,6 +87,7 @@ files =
tests/test_utils, tests/test_utils,
tests/handlers/test_password_providers.py, tests/handlers/test_password_providers.py,
tests/handlers/test_room_summary.py, tests/handlers/test_room_summary.py,
tests/handlers/test_send_email.py,
tests/rest/client/v1/test_login.py, tests/rest/client/v1/test_login.py,
tests/rest/client/v2_alpha/test_auth.py, tests/rest/client/v2_alpha/test_auth.py,
tests/util/test_itertools.py, tests/util/test_itertools.py,

View File

@ -19,9 +19,12 @@ from email.mime.text import MIMEText
from io import BytesIO from io import BytesIO
from typing import TYPE_CHECKING, Optional from typing import TYPE_CHECKING, Optional
from pkg_resources import parse_version
import twisted
from twisted.internet.defer import Deferred from twisted.internet.defer import Deferred
from twisted.internet.interfaces import IReactorTCP from twisted.internet.interfaces import IOpenSSLContextFactory, IReactorTCP
from twisted.mail.smtp import ESMTPSenderFactory from twisted.mail.smtp import ESMTPSender, ESMTPSenderFactory
from synapse.logging.context import make_deferred_yieldable from synapse.logging.context import make_deferred_yieldable
@ -30,6 +33,19 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
_is_old_twisted = parse_version(twisted.__version__) < parse_version("21")
class _NoTLSESMTPSender(ESMTPSender):
"""Extend ESMTPSender to disable TLS
Unfortunately, before Twisted 21.2, ESMTPSender doesn't give an easy way to disable
TLS, so we override its internal method which it uses to generate a context factory.
"""
def _getContextFactory(self) -> Optional[IOpenSSLContextFactory]:
return None
async def _sendmail( async def _sendmail(
reactor: IReactorTCP, reactor: IReactorTCP,
@ -42,7 +58,7 @@ async def _sendmail(
password: Optional[bytes] = None, password: Optional[bytes] = None,
require_auth: bool = False, require_auth: bool = False,
require_tls: bool = False, require_tls: bool = False,
tls_hostname: Optional[str] = None, enable_tls: bool = True,
) -> None: ) -> None:
"""A simple wrapper around ESMTPSenderFactory, to allow substitution in tests """A simple wrapper around ESMTPSenderFactory, to allow substitution in tests
@ -57,13 +73,14 @@ async def _sendmail(
password: password to give when authenticating password: password to give when authenticating
require_auth: if auth is not offered, fail the request require_auth: if auth is not offered, fail the request
require_tls: if TLS is not offered, fail the reqest require_tls: if TLS is not offered, fail the reqest
tls_hostname: TLS hostname to check for. None to disable TLS. enable_tls: True to enable TLS. If this is False and require_tls is True,
the request will fail.
""" """
msg = BytesIO(msg_bytes) msg = BytesIO(msg_bytes)
d: "Deferred[object]" = Deferred() d: "Deferred[object]" = Deferred()
factory = ESMTPSenderFactory( def build_sender_factory(**kwargs) -> ESMTPSenderFactory:
return ESMTPSenderFactory(
username, username,
password, password,
from_addr, from_addr,
@ -73,9 +90,21 @@ async def _sendmail(
heloFallback=True, heloFallback=True,
requireAuthentication=require_auth, requireAuthentication=require_auth,
requireTransportSecurity=require_tls, requireTransportSecurity=require_tls,
hostname=tls_hostname, **kwargs,
) )
if _is_old_twisted:
# before twisted 21.2, we have to override the ESMTPSender protocol to disable
# TLS
factory = build_sender_factory()
if not enable_tls:
factory.protocol = _NoTLSESMTPSender
else:
# for twisted 21.2 and later, there is a 'hostname' parameter which we should
# set to enable TLS.
factory = build_sender_factory(hostname=smtphost if enable_tls else None)
# the IReactorTCP interface claims host has to be a bytes, which seems to be wrong # the IReactorTCP interface claims host has to be a bytes, which seems to be wrong
reactor.connectTCP(smtphost, smtpport, factory, timeout=30, bindAddress=None) # type: ignore[arg-type] reactor.connectTCP(smtphost, smtpport, factory, timeout=30, bindAddress=None) # type: ignore[arg-type]
@ -154,5 +183,5 @@ class SendEmailHandler:
password=self._smtp_pass, password=self._smtp_pass,
require_auth=self._smtp_user is not None, require_auth=self._smtp_user is not None,
require_tls=self._require_transport_security, require_tls=self._require_transport_security,
tls_hostname=self._smtp_host if self._enable_tls else None, enable_tls=self._enable_tls,
) )

View File

@ -0,0 +1,112 @@
# Copyright 2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Tuple
from zope.interface import implementer
from twisted.internet import defer
from twisted.internet.address import IPv4Address
from twisted.internet.defer import ensureDeferred
from twisted.mail import interfaces, smtp
from tests.server import FakeTransport
from tests.unittest import HomeserverTestCase
@implementer(interfaces.IMessageDelivery)
class _DummyMessageDelivery:
def __init__(self):
# (recipient, message) tuples
self.messages: List[Tuple[smtp.Address, bytes]] = []
def receivedHeader(self, helo, origin, recipients):
return None
def validateFrom(self, helo, origin):
return origin
def record_message(self, recipient: smtp.Address, message: bytes):
self.messages.append((recipient, message))
def validateTo(self, user: smtp.User):
return lambda: _DummyMessage(self, user)
@implementer(interfaces.IMessageSMTP)
class _DummyMessage:
"""IMessageSMTP implementation which saves the message delivered to it
to the _DummyMessageDelivery object.
"""
def __init__(self, delivery: _DummyMessageDelivery, user: smtp.User):
self._delivery = delivery
self._user = user
self._buffer: List[bytes] = []
def lineReceived(self, line):
self._buffer.append(line)
def eomReceived(self):
message = b"\n".join(self._buffer) + b"\n"
self._delivery.record_message(self._user.dest, message)
return defer.succeed(b"saved")
def connectionLost(self):
pass
class SendEmailHandlerTestCase(HomeserverTestCase):
def test_send_email(self):
"""Happy-path test that we can send email to a non-TLS server."""
h = self.hs.get_send_email_handler()
d = ensureDeferred(
h.send_email(
"foo@bar.com", "test subject", "Tests", "HTML content", "Text content"
)
)
# there should be an attempt to connect to localhost:25
self.assertEqual(len(self.reactor.tcpClients), 1)
(host, port, client_factory, _timeout, _bindAddress) = self.reactor.tcpClients[
0
]
self.assertEqual(host, "localhost")
self.assertEqual(port, 25)
# wire it up to an SMTP server
message_delivery = _DummyMessageDelivery()
server_protocol = smtp.ESMTP()
server_protocol.delivery = message_delivery
# make sure that the server uses the test reactor to set timeouts
server_protocol.callLater = self.reactor.callLater # type: ignore[assignment]
client_protocol = client_factory.buildProtocol(None)
client_protocol.makeConnection(FakeTransport(server_protocol, self.reactor))
server_protocol.makeConnection(
FakeTransport(
client_protocol,
self.reactor,
peer_address=IPv4Address("TCP", "127.0.0.1", 1234),
)
)
# the message should now get delivered
self.get_success(d, by=0.1)
# check it arrived
self.assertEqual(len(message_delivery.messages), 1)
user, msg = message_delivery.messages.pop()
self.assertEqual(str(user), "foo@bar.com")
self.assertIn(b"Subject: test subject", msg)

View File

@ -10,9 +10,10 @@ from zope.interface import implementer
from twisted.internet import address, threads, udp from twisted.internet import address, threads, udp
from twisted.internet._resolver import SimpleResolverComplexifier from twisted.internet._resolver import SimpleResolverComplexifier
from twisted.internet.defer import Deferred, fail, succeed from twisted.internet.defer import Deferred, fail, maybeDeferred, succeed
from twisted.internet.error import DNSLookupError from twisted.internet.error import DNSLookupError
from twisted.internet.interfaces import ( from twisted.internet.interfaces import (
IAddress,
IHostnameResolver, IHostnameResolver,
IProtocol, IProtocol,
IPullProducer, IPullProducer,
@ -511,6 +512,9 @@ class FakeTransport:
will get called back for connectionLost() notifications etc. will get called back for connectionLost() notifications etc.
""" """
_peer_address: Optional[IAddress] = attr.ib(default=None)
"""The value to be returend by getPeer"""
disconnecting = False disconnecting = False
disconnected = False disconnected = False
connected = True connected = True
@ -519,7 +523,7 @@ class FakeTransport:
autoflush = attr.ib(default=True) autoflush = attr.ib(default=True)
def getPeer(self): def getPeer(self):
return None return self._peer_address
def getHost(self): def getHost(self):
return None return None
@ -572,7 +576,12 @@ class FakeTransport:
self.producerStreaming = streaming self.producerStreaming = streaming
def _produce(): def _produce():
d = self.producer.resumeProducing() if not self.producer:
# we've been unregistered
return
# some implementations of IProducer (for example, FileSender)
# don't return a deferred.
d = maybeDeferred(self.producer.resumeProducing)
d.addCallback(lambda x: self._reactor.callLater(0.1, _produce)) d.addCallback(lambda x: self._reactor.callLater(0.1, _produce))
if not streaming: if not streaming: