800 lines
		
	
	
		
			29 KiB
		
	
	
	
		
			Python
		
	
	
			
		
		
	
	
			800 lines
		
	
	
		
			29 KiB
		
	
	
	
		
			Python
		
	
	
| # -*- coding: utf-8 -*-
 | |
| # Copyright 2014-2016 OpenMarket Ltd
 | |
| # Copyright 2017 New Vector Ltd.
 | |
| #
 | |
| # Licensed under the Apache License, Version 2.0 (the "License");
 | |
| # you may not use this file except in compliance with the License.
 | |
| # You may obtain a copy of the License at
 | |
| #
 | |
| #     http://www.apache.org/licenses/LICENSE-2.0
 | |
| #
 | |
| # Unless required by applicable law or agreed to in writing, software
 | |
| # distributed under the License is distributed on an "AS IS" BASIS,
 | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | |
| # See the License for the specific language governing permissions and
 | |
| # limitations under the License.
 | |
| 
 | |
| import hashlib
 | |
| import logging
 | |
| import urllib
 | |
| from collections import namedtuple
 | |
| 
 | |
| from signedjson.key import (
 | |
|     decode_verify_key_bytes,
 | |
|     encode_verify_key_base64,
 | |
|     is_signing_algorithm_supported,
 | |
| )
 | |
| from signedjson.sign import (
 | |
|     SignatureVerifyException,
 | |
|     encode_canonical_json,
 | |
|     sign_json,
 | |
|     signature_ids,
 | |
|     verify_signed_json,
 | |
| )
 | |
| from unpaddedbase64 import decode_base64, encode_base64
 | |
| 
 | |
| from OpenSSL import crypto
 | |
| from twisted.internet import defer
 | |
| 
 | |
| from synapse.api.errors import Codes, SynapseError
 | |
| from synapse.crypto.keyclient import fetch_server_key
 | |
| from synapse.util import logcontext, unwrapFirstError
 | |
| from synapse.util.logcontext import (
 | |
|     PreserveLoggingContext,
 | |
|     preserve_fn,
 | |
|     run_in_background,
 | |
| )
 | |
| from synapse.util.metrics import Measure
 | |
| 
 | |
| logger = logging.getLogger(__name__)
 | |
| 
 | |
| 
 | |
VerifyKeyRequest = namedtuple(
    "VerifyRequest",
    ["server_name", "key_ids", "json_object", "deferred"],
)
"""
A pending request to find a verify key for, and then verify, a JSON object.

Attributes:
    server_name(str): The name of the server whose signature is being checked.
    key_ids(set(str)): The key_ids, any one of which could be used to verify
        the JSON object.
    json_object(dict): The JSON object to verify.
    deferred(Deferred[str, str, nacl.signing.VerifyKey]):
        Resolves to a (server_name, key_id, verify_key) tuple once a verify
        key has been fetched. The deferred's callbacks are run with no
        logcontext.
"""
 | |
| 
 | |
| 
 | |
class KeyLookupError(ValueError):
    """Raised when a fetched key response fails validation, e.g. it is not
    signed by the expected server or its TLS certificate does not match.
    """
 | |
| 
 | |
| 
 | |
class Keyring(object):
    """Looks up the signing keys of remote servers and verifies the signatures
    on JSON objects against them.

    Keys are looked for, in order, in the local datastore, via the configured
    perspective servers, and finally directly on the origin server. Keys that
    are fetched remotely are written back to the datastore.
    """

    def __init__(self, hs):
        """
        Args:
            hs: the homeserver object. It must provide get_datastore(),
                get_clock(), get_http_client() and get_config(); direct key
                fetches also read its tls_server_context_factory attribute.
        """
        self.store = hs.get_datastore()
        self.clock = hs.get_clock()
        self.client = hs.get_http_client()
        self.config = hs.get_config()
        # map from perspective server name to a dict of key_id -> verify key
        # for that perspective server
        self.perspective_servers = self.config.perspectives
        self.hs = hs

        # map from server name to Deferred. Has an entry for each server with
        # an ongoing key download; the Deferred completes once the download
        # completes.
        #
        # These are regular, logcontext-agnostic Deferreds.
        self.key_downloads = {}

    def verify_json_for_server(self, server_name, json_object):
        """Verifies the signature of a single json object.

        Thin wrapper around verify_json_objects_for_server for one
        (server_name, json_object) pair.

        Args:
            server_name (str): the server which should have signed the object
            json_object (dict): the signed JSON object

        Returns:
            Deferred: the (single) deferred returned by
                verify_json_objects_for_server, made yieldable under the
                synapse logcontext rules.
        """
        return logcontext.make_deferred_yieldable(
            self.verify_json_objects_for_server(
                [(server_name, json_object)]
            )[0]
        )

    def verify_json_objects_for_server(self, server_and_json):
        """Bulk verifies signatures of json objects, bulk fetching keys as
        necessary.

        Args:
            server_and_json (list): List of pairs of (server_name, json_object)

        Returns:
            List<Deferred>: for each input pair, a deferred indicating success
                or failure to verify each json object's signature for the given
                server_name. The deferreds run their callbacks in the sentinel
                logcontext.
        """
        verify_requests = []

        for server_name, json_object in server_and_json:

            key_ids = signature_ids(json_object, server_name)
            if not key_ids:
                logger.warning("Request from %s: no supported signature keys",
                               server_name)
                # Nothing we could verify against: fail this request up front.
                deferred = defer.fail(SynapseError(
                    400,
                    "Not signed with a supported algorithm",
                    Codes.UNAUTHORIZED,
                ))
            else:
                deferred = defer.Deferred()

            logger.debug("Verifying for %s with key_ids %s",
                         server_name, key_ids)

            verify_request = VerifyKeyRequest(
                server_name, key_ids, json_object, deferred
            )

            verify_requests.append(verify_request)

        run_in_background(self._start_key_lookups, verify_requests)

        # Pass those keys to handle_key_deferred so that the json object
        # signatures can be verified
        handle = preserve_fn(_handle_key_deferred)
        return [
            handle(rq) for rq in verify_requests
        ]

    @defer.inlineCallbacks
    def _start_key_lookups(self, verify_requests):
        """Sets off the key fetches for each verify request

        Once each fetch completes, verify_request.deferred will be resolved.

        Args:
            verify_requests (List[VerifyKeyRequest]):
        """

        try:
            # create a deferred for each server we're going to look up the keys
            # for; we'll resolve them once we have completed our lookups.
            # These will be passed into wait_for_previous_lookups to block
            # any other lookups until we have finished.
            # The deferreds are called with no logcontext.
            server_to_deferred = {
                rq.server_name: defer.Deferred()
                for rq in verify_requests
            }

            # We want to wait for any previous lookups to complete before
            # proceeding.
            yield self.wait_for_previous_lookups(
                [rq.server_name for rq in verify_requests],
                server_to_deferred,
            )

            # Actually start fetching keys.
            self._get_server_verify_keys(verify_requests)

            # When we've finished fetching all the keys for a given server_name,
            # resolve the deferred passed to `wait_for_previous_lookups` so that
            # any lookups waiting will proceed.
            #
            # map from server name to a set of request ids
            server_to_request_ids = {}

            for verify_request in verify_requests:
                server_name = verify_request.server_name
                request_id = id(verify_request)
                server_to_request_ids.setdefault(server_name, set()).add(request_id)

            def remove_deferreds(res, verify_request):
                server_name = verify_request.server_name
                request_id = id(verify_request)
                server_to_request_ids[server_name].discard(request_id)
                if not server_to_request_ids[server_name]:
                    # that was the last outstanding request for this server
                    d = server_to_deferred.pop(server_name, None)
                    if d:
                        d.callback(None)
                # pass the result/failure through to later callbacks untouched
                return res

            for verify_request in verify_requests:
                verify_request.deferred.addBoth(
                    remove_deferreds, verify_request,
                )
        except Exception:
            # NOTE(review): this handler only logs; it does not itself errback
            # the verify requests' deferreds.
            logger.exception("Error starting key lookups")

    @defer.inlineCallbacks
    def wait_for_previous_lookups(self, server_names, server_to_deferred):
        """Waits for any previous key lookups for the given servers to finish.

        Args:
            server_names (list): list of server_names we want to lookup
            server_to_deferred (dict): server_name to deferred which gets
                resolved once we've finished looking up keys for that server.
                The Deferreds should be regular twisted ones which call their
                callbacks with no logcontext.

        Returns: a Deferred which resolves once all key lookups for the given
            servers have completed. Follows the synapse rules of logcontext
            preservation.
        """
        # Loop rather than wait once: by the time the downloads we waited on
        # have finished, self.key_downloads may hold new entries for these
        # servers.
        while True:
            wait_on = [
                self.key_downloads[server_name]
                for server_name in server_names
                if server_name in self.key_downloads
            ]
            if wait_on:
                with PreserveLoggingContext():
                    yield defer.DeferredList(wait_on)
            else:
                break

        def rm(r, server_name_):
            self.key_downloads.pop(server_name_, None)
            return r

        # Register our own deferreds as the in-flight downloads, and drop each
        # entry again once it resolves.
        for server_name, deferred in server_to_deferred.items():
            self.key_downloads[server_name] = deferred
            deferred.addBoth(rm, server_name)

    def _get_server_verify_keys(self, verify_requests):
        """Tries to find at least one key for each verify request

        For each verify_request, verify_request.deferred is called back with
        params (server_name, key_id, VerifyKey) if a key is found, or errbacked
        with a SynapseError if none of the keys are found.

        Args:
            verify_requests (list[VerifyKeyRequest]): list of verify requests
        """

        # These are functions that produce keys given a list of key ids
        key_fetch_fns = (
            self.get_keys_from_store,  # First try the local store
            self.get_keys_from_perspectives,  # Then try via perspectives
            self.get_keys_from_server,  # Then try directly
        )

        @defer.inlineCallbacks
        def do_iterations():
            with Measure(self.clock, "get_server_verify_keys"):
                # dict[str, dict[str, VerifyKey]]: results so far.
                # map server_name -> key_id -> VerifyKey
                merged_results = {}

                # dict[str, set(str)]: keys to fetch for each server
                missing_keys = {}
                for verify_request in verify_requests:
                    missing_keys.setdefault(verify_request.server_name, set()).update(
                        verify_request.key_ids
                    )

                for fn in key_fetch_fns:
                    results = yield fn(missing_keys.items())
                    merged_results.update(results)

                    # We now need to figure out which verify requests we have keys
                    # for and which we don't
                    missing_keys = {}
                    requests_missing_keys = []
                    for verify_request in verify_requests:
                        server_name = verify_request.server_name
                        result_keys = merged_results[server_name]

                        if verify_request.deferred.called:
                            # We've already called this deferred, which probably
                            # means that we've already found a key for it.
                            continue

                        for key_id in verify_request.key_ids:
                            if key_id in result_keys:
                                with PreserveLoggingContext():
                                    verify_request.deferred.callback((
                                        server_name,
                                        key_id,
                                        result_keys[key_id],
                                    ))
                                break
                        else:
                            # The else block is only reached if the loop above
                            # doesn't break.
                            missing_keys.setdefault(server_name, set()).update(
                                verify_request.key_ids
                            )
                            requests_missing_keys.append(verify_request)

                    if not missing_keys:
                        break

                # Anything still missing after the last fetcher has no key.
                with PreserveLoggingContext():
                    for verify_request in requests_missing_keys:
                        verify_request.deferred.errback(SynapseError(
                            401,
                            "No key for %s with id %s" % (
                                verify_request.server_name, verify_request.key_ids,
                            ),
                            Codes.UNAUTHORIZED,
                        ))

        def on_err(err):
            # If the lookup itself blew up, fail every request that has not
            # already been resolved.
            with PreserveLoggingContext():
                for verify_request in verify_requests:
                    if not verify_request.deferred.called:
                        verify_request.deferred.errback(err)

        run_in_background(do_iterations).addErrback(on_err)

    @defer.inlineCallbacks
    def get_keys_from_store(self, server_name_and_key_ids):
        """Looks the requested keys up in the local datastore.

        Args:
            server_name_and_key_ids (list[(str, iterable[str])]):
                list of (server_name, iterable[key_id]) tuples to fetch keys for

        Returns:
            Deferred: resolves to dict[str, dict[str, VerifyKey]]: map from
                server_name -> key_id -> VerifyKey
        """
        res = yield logcontext.make_deferred_yieldable(defer.gatherResults(
            [
                run_in_background(
                    self.store.get_server_verify_keys,
                    server_name, key_ids,
                # pair each store result with the server it was fetched for
                ).addCallback(lambda ks, server: (server, ks), server_name)
                for server_name, key_ids in server_name_and_key_ids
            ],
            consumeErrors=True,
        ).addErrback(unwrapFirstError))

        defer.returnValue(dict(res))

    @defer.inlineCallbacks
    def get_keys_from_perspectives(self, server_name_and_key_ids):
        """Asks every configured perspective server for the requested keys.

        A perspective server which fails is logged and contributes no keys;
        it does not fail the lookup as a whole.

        Args:
            server_name_and_key_ids (list[(str, iterable[str])]):
                list of (server_name, iterable[key_id]) tuples to fetch keys for

        Returns:
            Deferred: resolves to dict[str, dict[str, VerifyKey]]: the union,
                over all perspective servers, of
                server_name -> key_id -> VerifyKey
        """
        @defer.inlineCallbacks
        def get_key(perspective_name, perspective_keys):
            try:
                result = yield self.get_server_verify_key_v2_indirect(
                    server_name_and_key_ids, perspective_name, perspective_keys
                )
                defer.returnValue(result)
            except Exception as e:
                logger.exception(
                    "Unable to get key from %r: %s %s",
                    perspective_name,
                    type(e).__name__, str(e),
                )
                defer.returnValue({})

        results = yield logcontext.make_deferred_yieldable(defer.gatherResults(
            [
                run_in_background(get_key, p_name, p_keys)
                for p_name, p_keys in self.perspective_servers.items()
            ],
            consumeErrors=True,
        ).addErrback(unwrapFirstError))

        union_of_keys = {}
        for result in results:
            for server_name, keys in result.items():
                union_of_keys.setdefault(server_name, {}).update(keys)

        defer.returnValue(union_of_keys)

    @defer.inlineCallbacks
    def get_keys_from_server(self, server_name_and_key_ids):
        """Fetches the requested keys directly from each origin server.

        The v2 key API is tried first; if that raises or yields nothing, the
        v1 API is used instead. An exception from the v1 fetch is not caught
        here.

        Args:
            server_name_and_key_ids (list[(str, iterable[str])]):
                list of (server_name, iterable[key_id]) tuples to fetch keys for

        Returns:
            Deferred: resolves to dict[str, dict[str, VerifyKey]]: map from
                server_name -> key_id -> VerifyKey, omitting servers for which
                no keys were obtained.
        """
        @defer.inlineCallbacks
        def get_key(server_name, key_ids):
            keys = None
            try:
                keys = yield self.get_server_verify_key_v2_direct(
                    server_name, key_ids
                )
            except Exception as e:
                logger.info(
                    "Unable to get key %r for %r directly: %s %s",
                    key_ids, server_name,
                    type(e).__name__, str(e),
                )

            if not keys:
                keys = yield self.get_server_verify_key_v1_direct(
                    server_name, key_ids
                )

                # the v1 fetch returns a bare key_id -> VerifyKey dict; wrap
                # it so both paths return server_name -> key_id -> VerifyKey
                keys = {server_name: keys}

            defer.returnValue(keys)

        results = yield logcontext.make_deferred_yieldable(defer.gatherResults(
            [
                run_in_background(get_key, server_name, key_ids)
                for server_name, key_ids in server_name_and_key_ids
            ],
            consumeErrors=True,
        ).addErrback(unwrapFirstError))

        merged = {}
        for result in results:
            merged.update(result)

        defer.returnValue({
            server_name: keys
            for server_name, keys in merged.items()
            if keys
        })

    @defer.inlineCallbacks
    def get_server_verify_key_v2_indirect(self, server_names_and_key_ids,
                                          perspective_name,
                                          perspective_keys):
        """Queries a perspective server for keys via /_matrix/key/v2/query.

        Each entry of the response must be signed by the perspective server
        with one of the keys in perspective_keys. The keys obtained are also
        written to the datastore.

        Args:
            server_names_and_key_ids (iterable[(str, iterable[str])]):
                (server_name, iterable[key_id]) tuples to fetch keys for
            perspective_name (str): the perspective server to query
            perspective_keys (dict): map from key_id to the verify key of the
                perspective server

        Returns:
            Deferred: resolves to dict[str, dict[str, VerifyKey]]: map from
                server_name -> key_id -> VerifyKey

        Raises:
            KeyLookupError: if a response entry is not signed by the
                perspective server, or not with a key we know.
        """
        # TODO(mark): Set the minimum_valid_until_ts to that needed by
        # the events being validated or the current time if validating
        # an incoming request.
        query_response = yield self.client.post_json(
            destination=perspective_name,
            path=b"/_matrix/key/v2/query",
            data={
                u"server_keys": {
                    server_name: {
                        key_id: {
                            u"minimum_valid_until_ts": 0
                        } for key_id in key_ids
                    }
                    for server_name, key_ids in server_names_and_key_ids
                }
            },
            long_retries=True,
        )

        keys = {}

        responses = query_response["server_keys"]

        for response in responses:
            if (u"signatures" not in response
                    or perspective_name not in response[u"signatures"]):
                raise KeyLookupError(
                    "Key response not signed by perspective server"
                    " %r" % (perspective_name,)
                )

            verified = False
            for key_id in response[u"signatures"][perspective_name]:
                if key_id in perspective_keys:
                    verify_signed_json(
                        response,
                        perspective_name,
                        perspective_keys[key_id]
                    )
                    verified = True

            if not verified:
                # Use the module logger (not the root `logging` module) so
                # that this message is attributed to this module.
                logger.info(
                    "Response from perspective server %r not signed with a"
                    " known key, signed with: %r, known keys: %r",
                    perspective_name,
                    list(response[u"signatures"][perspective_name]),
                    list(perspective_keys)
                )
                raise KeyLookupError(
                    "Response not signed with a known key for perspective"
                    " server %r" % (perspective_name,)
                )

            # The perspective server answers on behalf of other servers, so
            # the response's server_name need not match perspective_name.
            processed_response = yield self.process_v2_response(
                perspective_name, response, only_from_server=False
            )

            for server_name, response_keys in processed_response.items():
                keys.setdefault(server_name, {}).update(response_keys)

        yield logcontext.make_deferred_yieldable(defer.gatherResults(
            [
                run_in_background(
                    self.store_keys,
                    server_name=server_name,
                    from_server=perspective_name,
                    verify_keys=response_keys,
                )
                for server_name, response_keys in keys.items()
            ],
            consumeErrors=True
        ).addErrback(unwrapFirstError))

        defer.returnValue(keys)

    @defer.inlineCallbacks
    def get_server_verify_key_v2_direct(self, server_name, key_ids):
        """Fetches keys from the origin server via /_matrix/key/v2/server.

        One request is made per key id, except that a key id which was already
        returned by an earlier response is not requested again. Each response
        must be signed by the server and list the sha256 fingerprint of the TLS
        certificate it was served with. The keys obtained are also written to
        the datastore.

        Args:
            server_name (str): the server to fetch keys from
            key_ids (iterable[str]): the key ids to fetch

        Returns:
            Deferred: resolves to dict[str, dict[str, VerifyKey]]: map from
                server_name -> key_id -> VerifyKey

        Raises:
            KeyLookupError: if a response is unsigned, lacks TLS fingerprints,
                or does not list the certificate it was served with.
        """
        # map server_name -> key_id -> VerifyKey
        keys = {}

        for requested_key_id in key_ids:
            # `keys` is keyed by server name, so look inside this server's
            # entry to see whether an earlier response already gave us the key.
            if requested_key_id in keys.get(server_name, {}):
                continue

            # NOTE(review): `urllib.quote` and %-formatting a bytes literal
            # and then calling .encode() are Python 2 idioms; this line needs
            # revisiting for a Python 3 port.
            (response, tls_certificate) = yield fetch_server_key(
                server_name, self.hs.tls_server_context_factory,
                path=(b"/_matrix/key/v2/server/%s" % (
                    urllib.quote(requested_key_id),
                )).encode("ascii"),
            )

            if (u"signatures" not in response
                    or server_name not in response[u"signatures"]):
                raise KeyLookupError("Key response not signed by remote server")

            if "tls_fingerprints" not in response:
                raise KeyLookupError("Key response missing TLS fingerprints")

            certificate_bytes = crypto.dump_certificate(
                crypto.FILETYPE_ASN1, tls_certificate
            )
            sha256_fingerprint = hashlib.sha256(certificate_bytes).digest()
            sha256_fingerprint_b64 = encode_base64(sha256_fingerprint)

            response_sha256_fingerprints = set()
            for fingerprint in response[u"tls_fingerprints"]:
                if u"sha256" in fingerprint:
                    response_sha256_fingerprints.add(fingerprint[u"sha256"])

            if sha256_fingerprint_b64 not in response_sha256_fingerprints:
                raise KeyLookupError("TLS certificate not allowed by fingerprints")

            response_keys = yield self.process_v2_response(
                from_server=server_name,
                requested_ids=[requested_key_id],
                response_json=response,
            )

            # Merge per server rather than replacing the server's entry, so
            # keys from earlier responses are not dropped.
            for key_server_name, server_keys in response_keys.items():
                keys.setdefault(key_server_name, {}).update(server_keys)

        yield logcontext.make_deferred_yieldable(defer.gatherResults(
            [
                run_in_background(
                    self.store_keys,
                    server_name=key_server_name,
                    from_server=server_name,
                    verify_keys=verify_keys,
                )
                for key_server_name, verify_keys in keys.items()
            ],
            consumeErrors=True
        ).addErrback(unwrapFirstError))

        defer.returnValue(keys)

    @defer.inlineCallbacks
    def process_v2_response(self, from_server, response_json,
                            requested_ids=(), only_from_server=True):
        """Parses a v2 key response, checks the origin server's signatures on
        it, and stores the response JSON (re-signed by us) in the datastore.

        Args:
            from_server (str): the server the response was obtained from
            response_json (dict): the key response; its "server_name",
                "verify_keys", "old_verify_keys", "signatures" and
                "valid_until_ts" entries are read
            requested_ids (iterable[str]): key ids which were asked for; the
                response JSON is stored against these as well as against every
                key id found in the response
            only_from_server (bool): if True, the response's server_name must
                equal from_server

        Returns:
            Deferred: resolves to dict[str, dict[str, VerifyKey]]: a single
                entry mapping the response's server_name to
                key_id -> VerifyKey, covering both current and old keys.

        Raises:
            KeyLookupError: if only_from_server is set and the response is for
                a different server, or if the response carries a signature
                from its server whose key is not in "verify_keys".
        """
        time_now_ms = self.clock.time_msec()
        response_keys = {}
        verify_keys = {}
        for key_id, key_data in response_json["verify_keys"].items():
            if is_signing_algorithm_supported(key_id):
                key_base64 = key_data["key"]
                key_bytes = decode_base64(key_base64)
                verify_key = decode_verify_key_bytes(key_id, key_bytes)
                verify_key.time_added = time_now_ms
                verify_keys[key_id] = verify_key

        old_verify_keys = {}
        for key_id, key_data in response_json["old_verify_keys"].items():
            if is_signing_algorithm_supported(key_id):
                key_base64 = key_data["key"]
                key_bytes = decode_base64(key_base64)
                verify_key = decode_verify_key_bytes(key_id, key_bytes)
                verify_key.expired = key_data["expired_ts"]
                verify_key.time_added = time_now_ms
                old_verify_keys[key_id] = verify_key

        results = {}
        server_name = response_json["server_name"]
        if only_from_server:
            if server_name != from_server:
                raise KeyLookupError(
                    "Expected a response for server %r not %r" % (
                        from_server, server_name
                    )
                )
        # Every signature made by the origin server must be backed by a key
        # listed in the response; those with a supported algorithm are checked.
        for key_id in response_json["signatures"].get(server_name, {}):
            if key_id not in response_json["verify_keys"]:
                raise KeyLookupError(
                    "Key response must include verification keys for all"
                    " signatures"
                )
            if key_id in verify_keys:
                verify_signed_json(
                    response_json,
                    server_name,
                    verify_keys[key_id]
                )

        signed_key_json = sign_json(
            response_json,
            self.config.server_name,
            self.config.signing_key[0],
        )

        signed_key_json_bytes = encode_canonical_json(signed_key_json)
        ts_valid_until_ms = signed_key_json[u"valid_until_ts"]

        updated_key_ids = set(requested_ids)
        updated_key_ids.update(verify_keys)
        updated_key_ids.update(old_verify_keys)

        response_keys.update(verify_keys)
        response_keys.update(old_verify_keys)

        # NOTE(review): `from_server` below is the key's own server_name, not
        # this method's `from_server` argument - confirm this is intended.
        yield logcontext.make_deferred_yieldable(defer.gatherResults(
            [
                run_in_background(
                    self.store.store_server_keys_json,
                    server_name=server_name,
                    key_id=key_id,
                    from_server=server_name,
                    ts_now_ms=time_now_ms,
                    ts_expires_ms=ts_valid_until_ms,
                    key_json_bytes=signed_key_json_bytes,
                )
                for key_id in updated_key_ids
            ],
            consumeErrors=True,
        ).addErrback(unwrapFirstError))

        results[server_name] = response_keys

        defer.returnValue(results)

    @defer.inlineCallbacks
    def get_server_verify_key_v1_direct(self, server_name, key_ids):
        """Fetches the server's verification keys using the v1 key API.

        The response must be signed by the server and must contain the TLS
        certificate it was served with. The certificate and the keys are
        written to the datastore.

        Args:
            server_name (str): The name of the server to fetch a key for.
            key_ids (list of str): The key_ids to check for. (Not used to
                filter the result: every supported key in the response is
                returned.)

        Returns:
            Deferred: resolves to dict[str, VerifyKey]: map from key_id to
                VerifyKey.

        Raises:
            KeyLookupError: if the response is unsigned, lacks or mismatches
                the TLS certificate, or carries a signature whose key is not
                in "verify_keys".
        """

        # Try to fetch the key from the remote server.

        (response, tls_certificate) = yield fetch_server_key(
            server_name, self.hs.tls_server_context_factory
        )

        # Check the response.

        x509_certificate_bytes = crypto.dump_certificate(
            crypto.FILETYPE_ASN1, tls_certificate
        )

        if ("signatures" not in response
                or server_name not in response["signatures"]):
            raise KeyLookupError("Key response not signed by remote server")

        if "tls_certificate" not in response:
            raise KeyLookupError("Key response missing TLS certificate")

        tls_certificate_b64 = response["tls_certificate"]

        if encode_base64(x509_certificate_bytes) != tls_certificate_b64:
            raise KeyLookupError("TLS certificate doesn't match")

        # Cache the result in the datastore.

        time_now_ms = self.clock.time_msec()

        verify_keys = {}
        for key_id, key_base64 in response["verify_keys"].items():
            if is_signing_algorithm_supported(key_id):
                key_bytes = decode_base64(key_base64)
                verify_key = decode_verify_key_bytes(key_id, key_bytes)
                verify_key.time_added = time_now_ms
                verify_keys[key_id] = verify_key

        for key_id in response["signatures"][server_name]:
            if key_id not in response["verify_keys"]:
                raise KeyLookupError(
                    "Key response must include verification keys for all"
                    " signatures"
                )
            if key_id in verify_keys:
                verify_signed_json(
                    response,
                    server_name,
                    verify_keys[key_id]
                )

        yield self.store.store_server_certificate(
            server_name,
            server_name,
            time_now_ms,
            tls_certificate,
        )

        yield self.store_keys(
            server_name=server_name,
            from_server=server_name,
            verify_keys=verify_keys,
        )

        defer.returnValue(verify_keys)

    def store_keys(self, server_name, from_server, verify_keys):
        """Store a collection of verify keys for a given server
        Args:
            server_name(str): The name of the server the keys are for.
            from_server(str): The server the keys were downloaded from.
            verify_keys(dict): A mapping of key_id to VerifyKey.
        Returns:
            A deferred that completes when the keys are stored.
        """
        # TODO(markjh): Store whether the keys have expired.
        # NOTE(review): `from_server` is not passed on; server_name is given
        # to the store twice - confirm this is intended.
        return logcontext.make_deferred_yieldable(defer.gatherResults(
            [
                run_in_background(
                    self.store.store_server_verify_key,
                    server_name, server_name, key.time_added, key
                )
                for key in verify_keys.values()
            ],
            consumeErrors=True,
        ).addErrback(unwrapFirstError))
 | |
| 
 | |
| 
 | |
@defer.inlineCallbacks
def _handle_key_deferred(verify_request):
    """Waits for the key to become available, and then performs a verification

    Args:
        verify_request (VerifyKeyRequest):

    Returns:
        Deferred[None]

    Raises:
        SynapseError if there was a problem performing the verification:
            502 if fetching the key failed with an IOError, 401 if the key
            could not be obtained for any other reason or the signature does
            not verify.
    """
    server_name = verify_request.server_name
    try:
        # verify_request.deferred runs its callbacks with no logcontext, so
        # wait on it inside PreserveLoggingContext.
        with PreserveLoggingContext():
            _, key_id, verify_key = yield verify_request.deferred
    except IOError as e:
        logger.warning(
            "Got IOError when downloading keys for %s: %s %s",
            server_name, type(e).__name__, str(e),
        )
        raise SynapseError(
            502,
            "Error downloading keys for %s" % (server_name,),
            Codes.UNAUTHORIZED,
        )
    except Exception as e:
        logger.exception(
            "Got Exception when downloading keys for %s: %s %s",
            server_name, type(e).__name__, str(e),
        )
        raise SynapseError(
            401,
            "No key for %s with id %s" % (server_name, verify_request.key_ids),
            Codes.UNAUTHORIZED,
        )

    json_object = verify_request.json_object

    # Pass the values as logger arguments so the message is only formatted
    # when debug logging is enabled.
    logger.debug(
        "Got key %s %s:%s for server %s, verifying",
        key_id, verify_key.alg, verify_key.version, server_name,
    )
    try:
        verify_signed_json(json_object, server_name, verify_key)
    except SignatureVerifyException as e:
        logger.debug(
            "Error verifying signature for %s:%s:%s with key %s: %s",
            server_name, verify_key.alg, verify_key.version,
            encode_verify_key_base64(verify_key),
            str(e),
        )
        raise SynapseError(
            401,
            "Invalid signature for server %s with key %s:%s: %s" % (
                server_name, verify_key.alg, verify_key.version, str(e),
            ),
            Codes.UNAUTHORIZED,
        )
 |