Add missing type hints to synapse.crypto. (#11146)
And require type hints for this module.pull/11181/head
parent
09eff1b3db
commit
0f9adc99ad
|
@ -0,0 +1 @@
|
||||||
|
Add missing type hints to `synapse.crypto`.
|
3
mypy.ini
3
mypy.ini
|
@ -103,6 +103,9 @@ files =
|
||||||
[mypy-synapse.api.*]
|
[mypy-synapse.api.*]
|
||||||
disallow_untyped_defs = True
|
disallow_untyped_defs = True
|
||||||
|
|
||||||
|
[mypy-synapse.crypto.*]
|
||||||
|
disallow_untyped_defs = True
|
||||||
|
|
||||||
[mypy-synapse.events.*]
|
[mypy-synapse.events.*]
|
||||||
disallow_untyped_defs = True
|
disallow_untyped_defs = True
|
||||||
|
|
||||||
|
|
|
@ -29,9 +29,12 @@ from twisted.internet.ssl import (
|
||||||
TLSVersion,
|
TLSVersion,
|
||||||
platformTrust,
|
platformTrust,
|
||||||
)
|
)
|
||||||
|
from twisted.protocols.tls import TLSMemoryBIOProtocol
|
||||||
from twisted.python.failure import Failure
|
from twisted.python.failure import Failure
|
||||||
from twisted.web.iweb import IPolicyForHTTPS
|
from twisted.web.iweb import IPolicyForHTTPS
|
||||||
|
|
||||||
|
from synapse.config.homeserver import HomeServerConfig
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@ -51,7 +54,7 @@ class ServerContextFactory(ContextFactory):
|
||||||
per https://github.com/matrix-org/synapse/issues/1691
|
per https://github.com/matrix-org/synapse/issues/1691
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, config):
|
def __init__(self, config: HomeServerConfig):
|
||||||
# TODO: once pyOpenSSL exposes TLS_METHOD and SSL_CTX_set_min_proto_version,
|
# TODO: once pyOpenSSL exposes TLS_METHOD and SSL_CTX_set_min_proto_version,
|
||||||
# switch to those (see https://github.com/pyca/cryptography/issues/5379).
|
# switch to those (see https://github.com/pyca/cryptography/issues/5379).
|
||||||
#
|
#
|
||||||
|
@ -64,7 +67,7 @@ class ServerContextFactory(ContextFactory):
|
||||||
self.configure_context(self._context, config)
|
self.configure_context(self._context, config)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def configure_context(context, config):
|
def configure_context(context: SSL.Context, config: HomeServerConfig) -> None:
|
||||||
try:
|
try:
|
||||||
_ecCurve = crypto.get_elliptic_curve(_defaultCurveName)
|
_ecCurve = crypto.get_elliptic_curve(_defaultCurveName)
|
||||||
context.set_tmp_ecdh(_ecCurve)
|
context.set_tmp_ecdh(_ecCurve)
|
||||||
|
@ -75,14 +78,15 @@ class ServerContextFactory(ContextFactory):
|
||||||
SSL.OP_NO_SSLv2 | SSL.OP_NO_SSLv3 | SSL.OP_NO_TLSv1 | SSL.OP_NO_TLSv1_1
|
SSL.OP_NO_SSLv2 | SSL.OP_NO_SSLv3 | SSL.OP_NO_TLSv1 | SSL.OP_NO_TLSv1_1
|
||||||
)
|
)
|
||||||
context.use_certificate_chain_file(config.tls.tls_certificate_file)
|
context.use_certificate_chain_file(config.tls.tls_certificate_file)
|
||||||
|
assert config.tls.tls_private_key is not None
|
||||||
context.use_privatekey(config.tls.tls_private_key)
|
context.use_privatekey(config.tls.tls_private_key)
|
||||||
|
|
||||||
# https://hynek.me/articles/hardening-your-web-servers-ssl-ciphers/
|
# https://hynek.me/articles/hardening-your-web-servers-ssl-ciphers/
|
||||||
context.set_cipher_list(
|
context.set_cipher_list(
|
||||||
"ECDH+AESGCM:ECDH+CHACHA20:ECDH+AES256:ECDH+AES128:!aNULL:!SHA1:!AESCCM"
|
b"ECDH+AESGCM:ECDH+CHACHA20:ECDH+AES256:ECDH+AES128:!aNULL:!SHA1:!AESCCM"
|
||||||
)
|
)
|
||||||
|
|
||||||
def getContext(self):
|
def getContext(self) -> SSL.Context:
|
||||||
return self._context
|
return self._context
|
||||||
|
|
||||||
|
|
||||||
|
@ -98,7 +102,7 @@ class FederationPolicyForHTTPS:
|
||||||
constructs an SSLClientConnectionCreator factory accordingly.
|
constructs an SSLClientConnectionCreator factory accordingly.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, config):
|
def __init__(self, config: HomeServerConfig):
|
||||||
self._config = config
|
self._config = config
|
||||||
|
|
||||||
# Check if we're using a custom list of a CA certificates
|
# Check if we're using a custom list of a CA certificates
|
||||||
|
@ -131,7 +135,7 @@ class FederationPolicyForHTTPS:
|
||||||
self._config.tls.federation_certificate_verification_whitelist
|
self._config.tls.federation_certificate_verification_whitelist
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_options(self, host: bytes):
|
def get_options(self, host: bytes) -> IOpenSSLClientConnectionCreator:
|
||||||
# IPolicyForHTTPS.get_options takes bytes, but we want to compare
|
# IPolicyForHTTPS.get_options takes bytes, but we want to compare
|
||||||
# against the str whitelist. The hostnames in the whitelist are already
|
# against the str whitelist. The hostnames in the whitelist are already
|
||||||
# IDNA-encoded like the hosts will be here.
|
# IDNA-encoded like the hosts will be here.
|
||||||
|
@ -153,7 +157,9 @@ class FederationPolicyForHTTPS:
|
||||||
|
|
||||||
return SSLClientConnectionCreator(host, ssl_context, should_verify)
|
return SSLClientConnectionCreator(host, ssl_context, should_verify)
|
||||||
|
|
||||||
def creatorForNetloc(self, hostname, port):
|
def creatorForNetloc(
|
||||||
|
self, hostname: bytes, port: int
|
||||||
|
) -> IOpenSSLClientConnectionCreator:
|
||||||
"""Implements the IPolicyForHTTPS interface so that this can be passed
|
"""Implements the IPolicyForHTTPS interface so that this can be passed
|
||||||
directly to agents.
|
directly to agents.
|
||||||
"""
|
"""
|
||||||
|
@ -169,16 +175,18 @@ class RegularPolicyForHTTPS:
|
||||||
trust root.
|
trust root.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self) -> None:
|
||||||
trust_root = platformTrust()
|
trust_root = platformTrust()
|
||||||
self._ssl_context = CertificateOptions(trustRoot=trust_root).getContext()
|
self._ssl_context = CertificateOptions(trustRoot=trust_root).getContext()
|
||||||
self._ssl_context.set_info_callback(_context_info_cb)
|
self._ssl_context.set_info_callback(_context_info_cb)
|
||||||
|
|
||||||
def creatorForNetloc(self, hostname, port):
|
def creatorForNetloc(
|
||||||
|
self, hostname: bytes, port: int
|
||||||
|
) -> IOpenSSLClientConnectionCreator:
|
||||||
return SSLClientConnectionCreator(hostname, self._ssl_context, True)
|
return SSLClientConnectionCreator(hostname, self._ssl_context, True)
|
||||||
|
|
||||||
|
|
||||||
def _context_info_cb(ssl_connection, where, ret):
|
def _context_info_cb(ssl_connection: SSL.Connection, where: int, ret: int) -> None:
|
||||||
"""The 'information callback' for our openssl context objects.
|
"""The 'information callback' for our openssl context objects.
|
||||||
|
|
||||||
Note: Once this is set as the info callback on a Context object, the Context should
|
Note: Once this is set as the info callback on a Context object, the Context should
|
||||||
|
@ -204,11 +212,13 @@ class SSLClientConnectionCreator:
|
||||||
Replaces twisted.internet.ssl.ClientTLSOptions
|
Replaces twisted.internet.ssl.ClientTLSOptions
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, hostname: bytes, ctx, verify_certs: bool):
|
def __init__(self, hostname: bytes, ctx: SSL.Context, verify_certs: bool):
|
||||||
self._ctx = ctx
|
self._ctx = ctx
|
||||||
self._verifier = ConnectionVerifier(hostname, verify_certs)
|
self._verifier = ConnectionVerifier(hostname, verify_certs)
|
||||||
|
|
||||||
def clientConnectionForTLS(self, tls_protocol):
|
def clientConnectionForTLS(
|
||||||
|
self, tls_protocol: TLSMemoryBIOProtocol
|
||||||
|
) -> SSL.Connection:
|
||||||
context = self._ctx
|
context = self._ctx
|
||||||
connection = SSL.Connection(context, None)
|
connection = SSL.Connection(context, None)
|
||||||
|
|
||||||
|
@ -219,7 +229,7 @@ class SSLClientConnectionCreator:
|
||||||
# ... and we also gut-wrench a '_synapse_tls_verifier' attribute into the
|
# ... and we also gut-wrench a '_synapse_tls_verifier' attribute into the
|
||||||
# tls_protocol so that the SSL context's info callback has something to
|
# tls_protocol so that the SSL context's info callback has something to
|
||||||
# call to do the cert verification.
|
# call to do the cert verification.
|
||||||
tls_protocol._synapse_tls_verifier = self._verifier
|
tls_protocol._synapse_tls_verifier = self._verifier # type: ignore[attr-defined]
|
||||||
return connection
|
return connection
|
||||||
|
|
||||||
|
|
||||||
|
@ -244,7 +254,9 @@ class ConnectionVerifier:
|
||||||
self._hostnameBytes = hostname
|
self._hostnameBytes = hostname
|
||||||
self._hostnameASCII = self._hostnameBytes.decode("ascii")
|
self._hostnameASCII = self._hostnameBytes.decode("ascii")
|
||||||
|
|
||||||
def verify_context_info_cb(self, ssl_connection, where):
|
def verify_context_info_cb(
|
||||||
|
self, ssl_connection: SSL.Connection, where: int
|
||||||
|
) -> None:
|
||||||
if where & SSL.SSL_CB_HANDSHAKE_START and not self._is_ip_address:
|
if where & SSL.SSL_CB_HANDSHAKE_START and not self._is_ip_address:
|
||||||
ssl_connection.set_tlsext_host_name(self._hostnameBytes)
|
ssl_connection.set_tlsext_host_name(self._hostnameBytes)
|
||||||
|
|
||||||
|
|
|
@ -100,7 +100,7 @@ def compute_content_hash(
|
||||||
|
|
||||||
|
|
||||||
def compute_event_reference_hash(
|
def compute_event_reference_hash(
|
||||||
event, hash_algorithm: Hasher = hashlib.sha256
|
event: EventBase, hash_algorithm: Hasher = hashlib.sha256
|
||||||
) -> Tuple[str, bytes]:
|
) -> Tuple[str, bytes]:
|
||||||
"""Computes the event reference hash. This is the hash of the redacted
|
"""Computes the event reference hash. This is the hash of the redacted
|
||||||
event.
|
event.
|
||||||
|
|
|
@ -87,7 +87,7 @@ class VerifyJsonRequest:
|
||||||
server_name: str,
|
server_name: str,
|
||||||
json_object: JsonDict,
|
json_object: JsonDict,
|
||||||
minimum_valid_until_ms: int,
|
minimum_valid_until_ms: int,
|
||||||
):
|
) -> "VerifyJsonRequest":
|
||||||
"""Create a VerifyJsonRequest to verify all signatures on a signed JSON
|
"""Create a VerifyJsonRequest to verify all signatures on a signed JSON
|
||||||
object for the given server.
|
object for the given server.
|
||||||
"""
|
"""
|
||||||
|
@ -104,7 +104,7 @@ class VerifyJsonRequest:
|
||||||
server_name: str,
|
server_name: str,
|
||||||
event: EventBase,
|
event: EventBase,
|
||||||
minimum_valid_until_ms: int,
|
minimum_valid_until_ms: int,
|
||||||
):
|
) -> "VerifyJsonRequest":
|
||||||
"""Create a VerifyJsonRequest to verify all signatures on an event
|
"""Create a VerifyJsonRequest to verify all signatures on an event
|
||||||
object for the given server.
|
object for the given server.
|
||||||
"""
|
"""
|
||||||
|
@ -449,7 +449,9 @@ class StoreKeyFetcher(KeyFetcher):
|
||||||
|
|
||||||
self.store = hs.get_datastore()
|
self.store = hs.get_datastore()
|
||||||
|
|
||||||
async def _fetch_keys(self, keys_to_fetch: List[_FetchKeyRequest]):
|
async def _fetch_keys(
|
||||||
|
self, keys_to_fetch: List[_FetchKeyRequest]
|
||||||
|
) -> Dict[str, Dict[str, FetchKeyResult]]:
|
||||||
key_ids_to_fetch = (
|
key_ids_to_fetch = (
|
||||||
(queue_value.server_name, key_id)
|
(queue_value.server_name, key_id)
|
||||||
for queue_value in keys_to_fetch
|
for queue_value in keys_to_fetch
|
||||||
|
|
Loading…
Reference in New Issue