863 lines
		
	
	
		
			32 KiB
		
	
	
	
		
			Python
		
	
	
			
		
		
	
	
			863 lines
		
	
	
		
			32 KiB
		
	
	
	
		
			Python
		
	
	
| # -*- coding: utf-8 -*-
 | |
| # Copyright 2014-2016 OpenMarket Ltd
 | |
| # Copyright 2017, 2018 New Vector Ltd
 | |
| #
 | |
| # Licensed under the Apache License, Version 2.0 (the "License");
 | |
| # you may not use this file except in compliance with the License.
 | |
| # You may obtain a copy of the License at
 | |
| #
 | |
| #     http://www.apache.org/licenses/LICENSE-2.0
 | |
| #
 | |
| # Unless required by applicable law or agreed to in writing, software
 | |
| # distributed under the License is distributed on an "AS IS" BASIS,
 | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | |
| # See the License for the specific language governing permissions and
 | |
| # limitations under the License.
 | |
| 
 | |
| import logging
 | |
| import urllib
 | |
| from collections import defaultdict
 | |
| 
 | |
| import attr
 | |
| from signedjson.key import (
 | |
|     decode_verify_key_bytes,
 | |
|     encode_verify_key_base64,
 | |
|     is_signing_algorithm_supported,
 | |
| )
 | |
| from signedjson.sign import (
 | |
|     SignatureVerifyException,
 | |
|     encode_canonical_json,
 | |
|     signature_ids,
 | |
|     verify_signed_json,
 | |
| )
 | |
| from unpaddedbase64 import decode_base64
 | |
| 
 | |
| from twisted.internet import defer
 | |
| 
 | |
| from synapse.api.errors import (
 | |
|     Codes,
 | |
|     HttpResponseException,
 | |
|     RequestSendFailed,
 | |
|     SynapseError,
 | |
| )
 | |
| from synapse.logging.context import (
 | |
|     PreserveLoggingContext,
 | |
|     current_context,
 | |
|     make_deferred_yieldable,
 | |
|     preserve_fn,
 | |
|     run_in_background,
 | |
| )
 | |
| from synapse.storage.keys import FetchKeyResult
 | |
| from synapse.util import unwrapFirstError
 | |
| from synapse.util.async_helpers import yieldable_gather_results
 | |
| from synapse.util.metrics import Measure
 | |
| from synapse.util.retryutils import NotRetryingDestination
 | |
| 
 | |
# Module-level logger for this module, named after the module itself.
logger = logging.getLogger(__name__)
 | |
| 
 | |
| 
 | |
@attr.s(slots=True, cmp=False)
class VerifyJsonRequest(object):
    """
    A request to verify a JSON object.

    Attributes:
        server_name (str): The name of the server to verify against.

        json_object (dict): The JSON object to verify.

        minimum_valid_until_ts (int): time at which we require the signing key to
            be valid. (0 implies we don't care)

        request_name (str): an identifier for this json object (eg, an event id),
            used for logging.

        key_ids (Iterable[str]): The key_ids that could be used to verify the
            JSON object. Not passed to the constructor: it is computed from the
            signatures on ``json_object`` for ``server_name``.

        key_ready (Deferred[Tuple[str, str, nacl.signing.VerifyKey]]):
            A deferred (server_name, key_id, verify_key) tuple that resolves when
            a verify key has been fetched. The deferreds' callbacks are run with no
            logcontext.

            If we are unable to find a key which satisfies the request, the deferred
            errbacks with an M_UNAUTHORIZED SynapseError. If key fetching fails
            unexpectedly, it errbacks with the underlying error instead.
    """

    # NB: the order of the attr.ib() declarations defines the positional
    # arguments of the generated __init__, so it must not be changed.
    server_name = attr.ib()
    json_object = attr.ib()
    minimum_valid_until_ts = attr.ib()
    request_name = attr.ib()
    key_ids = attr.ib(init=False)
    key_ready = attr.ib(default=attr.Factory(defer.Deferred))

    def __attrs_post_init__(self):
        # Derive the candidate key ids from the signatures which ``server_name``
        # has placed on the object.
        self.key_ids = signature_ids(self.json_object, self.server_name)
 | |
| 
 | |
| 
 | |
class KeyLookupError(ValueError):
    """Raised when signing keys for a server cannot be fetched, or when a key
    response is unusable (for example malformed, or not properly signed)."""
 | |
| 
 | |
| 
 | |
class Keyring(object):
    """Verifies signatures on JSON objects, fetching signing keys as required.

    Keys are looked up by trying each of the configured key fetchers in turn
    (by default: the local store, then the notary ("perspectives") servers,
    then the origin server itself).
    """

    def __init__(self, hs, key_fetchers=None):
        self.clock = hs.get_clock()

        if key_fetchers is None:
            key_fetchers = (
                StoreKeyFetcher(hs),
                PerspectivesKeyFetcher(hs),
                ServerKeyFetcher(hs),
            )
        self._key_fetchers = key_fetchers

        # map from server name to Deferred. Has an entry for each server with
        # an ongoing key download; the Deferred completes once the download
        # completes.
        #
        # These are regular, logcontext-agnostic Deferreds.
        self.key_downloads = {}

    def verify_json_for_server(
        self, server_name, json_object, validity_time, request_name
    ):
        """Verify that a JSON object has been signed by a given server

        Args:
            server_name (str): name of the server which must have signed this object

            json_object (dict): object to be checked

            validity_time (int): timestamp at which we require the signing key to
                be valid. (0 implies we don't care)

            request_name (str): an identifier for this json object (eg, an event id)
                for logging.

        Returns:
            Deferred[None]: completes if the object was correctly signed, otherwise
                errbacks with an error
        """
        req = VerifyJsonRequest(server_name, json_object, validity_time, request_name)
        requests = (req,)
        return make_deferred_yieldable(self._verify_objects(requests)[0])

    def verify_json_objects_for_server(self, server_and_json):
        """Bulk verifies signatures of json objects, bulk fetching keys as
        necessary.

        Args:
            server_and_json (iterable[Tuple[str, dict, int, str]]):
                Iterable of (server_name, json_object, validity_time, request_name)
                tuples.

                validity_time is a timestamp at which the signing key must be
                valid.

                request_name is an identifier for this json object (eg, an event id)
                for logging.

        Returns:
            List<Deferred[None]>: for each input tuple, a deferred indicating success
                or failure to verify each json object's signature for the given
                server_name. The deferreds run their callbacks in the sentinel
                logcontext.
        """
        return self._verify_objects(
            VerifyJsonRequest(server_name, json_object, validity_time, request_name)
            for server_name, json_object, validity_time, request_name in server_and_json
        )

    def _verify_objects(self, verify_requests):
        """Does the work of verify_json_[objects_]for_server

        Args:
            verify_requests (iterable[VerifyJsonRequest]):
                Iterable of verification requests.

        Returns:
            List<Deferred[None]>: for each input item, a deferred indicating success
                or failure to verify each json object's signature for the given
                server_name. The deferreds run their callbacks in the sentinel
                logcontext.
        """
        # a list of VerifyJsonRequests which are awaiting a key lookup
        key_lookups = []
        handle = preserve_fn(_handle_key_deferred)

        def process(verify_request):
            """Process an entry in the request list

            Adds a key request to key_lookups, and returns a deferred which
            will complete or fail (in the sentinel context) when verification completes.
            """
            if not verify_request.key_ids:
                return defer.fail(
                    SynapseError(
                        400,
                        "Not signed by %s" % (verify_request.server_name,),
                        Codes.UNAUTHORIZED,
                    )
                )

            logger.debug(
                "Verifying %s for %s with key_ids %s, min_validity %i",
                verify_request.request_name,
                verify_request.server_name,
                verify_request.key_ids,
                verify_request.minimum_valid_until_ts,
            )

            # add the key request to the queue, but don't start it off yet.
            key_lookups.append(verify_request)

            # now run _handle_key_deferred, which will wait for the key request
            # to complete and then do the verification.
            #
            # We want _handle_key_deferred to log to the right context, so we
            # wrap it with preserve_fn (aka run_in_background)
            return handle(verify_request)

        results = [process(r) for r in verify_requests]

        if key_lookups:
            run_in_background(self._start_key_lookups, key_lookups)

        return results

    async def _start_key_lookups(self, verify_requests):
        """Sets off the key fetches for each verify request

        Once each fetch completes, verify_request.key_ready will be resolved.

        Args:
            verify_requests (List[VerifyJsonRequest]):
        """

        try:
            ctx = current_context()

            # map from server name to a set of outstanding request ids
            server_to_request_ids = {}

            for verify_request in verify_requests:
                server_name = verify_request.server_name
                request_id = id(verify_request)
                server_to_request_ids.setdefault(server_name, set()).add(request_id)

            # Wait for any previous lookups to complete before proceeding.
            await self.wait_for_previous_lookups(server_to_request_ids.keys())

            # take out a lock on each of the servers by sticking a Deferred in
            # key_downloads
            for server_name in server_to_request_ids.keys():
                self.key_downloads[server_name] = defer.Deferred()
                logger.debug("Got key lookup lock on %s", server_name)

            # When we've finished fetching all the keys for a given server_name,
            # drop the lock by resolving the deferred in key_downloads.
            def drop_server_lock(server_name):
                d = self.key_downloads.pop(server_name)
                d.callback(None)

            def lookup_done(res, verify_request):
                server_name = verify_request.server_name
                server_requests = server_to_request_ids[server_name]
                server_requests.remove(id(verify_request))

                # if there are no more requests for this server, we can drop the lock.
                if not server_requests:
                    with PreserveLoggingContext(ctx):
                        logger.debug("Releasing key lookup lock on %s", server_name)

                    # ... but not immediately, as that can cause stack explosions if
                    # we get a long queue of lookups.
                    self.clock.call_later(0, drop_server_lock, server_name)

                return res

            for verify_request in verify_requests:
                verify_request.key_ready.addBoth(lookup_done, verify_request)

            # Actually start fetching keys.
            self._get_server_verify_keys(verify_requests)
        except Exception:
            logger.exception("Error starting key lookups")

    async def wait_for_previous_lookups(self, server_names) -> None:
        """Waits for any previous key lookups for the given servers to finish.

        Loops until none of the given servers has an entry in key_downloads,
        since another lookup may take the lock between our wake-up and our check.

        Args:
            server_names (Iterable[str]): list of servers which we want to look up

        Returns:
            None, once all key lookups for the given servers have completed.
            Follows the synapse rules of logcontext preservation.
        """
        loop_count = 1
        while True:
            wait_on = [
                (server_name, self.key_downloads[server_name])
                for server_name in server_names
                if server_name in self.key_downloads
            ]
            if not wait_on:
                break
            logger.info(
                "Waiting for existing lookups for %s to complete [loop %i]",
                [w[0] for w in wait_on],
                loop_count,
            )
            with PreserveLoggingContext():
                await defer.DeferredList((w[1] for w in wait_on))

            loop_count += 1

    def _get_server_verify_keys(self, verify_requests):
        """Tries to find at least one key for each verify request

        For each verify_request, verify_request.key_ready is called back with
        params (server_name, key_id, VerifyKey) if a key is found, or errbacked
        with a SynapseError if none of the keys are found.

        Args:
            verify_requests (list[VerifyJsonRequest]): list of verify requests
        """

        remaining_requests = {rq for rq in verify_requests if not rq.key_ready.called}

        async def do_iterations():
            try:
                with Measure(self.clock, "get_server_verify_keys"):
                    for f in self._key_fetchers:
                        if not remaining_requests:
                            return
                        await self._attempt_key_fetches_with_fetcher(
                            f, remaining_requests
                        )

                    # look for any requests which weren't satisfied
                    with PreserveLoggingContext():
                        for verify_request in remaining_requests:
                            verify_request.key_ready.errback(
                                SynapseError(
                                    401,
                                    "No key for %s with ids in %s (min_validity %i)"
                                    % (
                                        verify_request.server_name,
                                        verify_request.key_ids,
                                        verify_request.minimum_valid_until_ts,
                                    ),
                                    Codes.UNAUTHORIZED,
                                )
                            )
            except Exception as err:
                # we don't really expect to get here, because any errors should already
                # have been caught and logged. But if we do, let's log the error (with
                # its traceback, since this is unexpected) and make sure that all of
                # the deferreds are resolved.
                logger.exception("Unexpected error in _get_server_verify_keys: %s", err)
                with PreserveLoggingContext():
                    for verify_request in remaining_requests:
                        if not verify_request.key_ready.called:
                            verify_request.key_ready.errback(err)

        run_in_background(do_iterations)

    async def _attempt_key_fetches_with_fetcher(self, fetcher, remaining_requests):
        """Use a key fetcher to attempt to satisfy some key requests

        Args:
            fetcher (KeyFetcher): fetcher to use to fetch the keys
            remaining_requests (set[VerifyJsonRequest]): outstanding key requests.
                Any successfully-completed requests will be removed from the set.
        """
        # dict[str, dict[str, int]]: keys to fetch.
        # server_name -> key_id -> min_valid_ts
        missing_keys = defaultdict(dict)

        for verify_request in remaining_requests:
            # any completed requests should already have been removed
            assert not verify_request.key_ready.called
            keys_for_server = missing_keys[verify_request.server_name]

            for key_id in verify_request.key_ids:
                # If we have several requests for the same key, then we only need to
                # request that key once, but we should do so with the greatest
                # min_valid_until_ts of the requests, so that we can satisfy all of
                # the requests.
                keys_for_server[key_id] = max(
                    keys_for_server.get(key_id, -1),
                    verify_request.minimum_valid_until_ts,
                )

        results = await fetcher.get_keys(missing_keys)

        completed = []
        for verify_request in remaining_requests:
            server_name = verify_request.server_name

            # see if any of the keys we got this time are sufficient to
            # complete this VerifyJsonRequest.
            result_keys = results.get(server_name, {})
            for key_id in verify_request.key_ids:
                fetch_key_result = result_keys.get(key_id)
                if not fetch_key_result:
                    # we didn't get a result for this key
                    continue

                if (
                    fetch_key_result.valid_until_ts
                    < verify_request.minimum_valid_until_ts
                ):
                    # key was not valid at this point
                    continue

                with PreserveLoggingContext():
                    verify_request.key_ready.callback(
                        (server_name, key_id, fetch_key_result.verify_key)
                    )
                completed.append(verify_request)
                break

        remaining_requests.difference_update(completed)
 | |
| 
 | |
| 
 | |
class KeyFetcher(object):
    """Abstract source of server signing keys.

    Concrete fetchers override :meth:`get_keys`.
    """

    async def get_keys(self, keys_to_fetch):
        """Look up a batch of signing keys.

        Args:
            keys_to_fetch (dict[str, dict[str, int]]):
                the keys to be fetched. server_name -> key_id -> min_valid_ts

        Returns:
            dict[str, dict[str, synapse.storage.keys.FetchKeyResult|None]]:
                map from server_name -> key_id -> FetchKeyResult. Keys which could
                not be found may simply be absent from the map.

        Raises:
            NotImplementedError: always, on this base class.
        """
        raise NotImplementedError
 | |
| 
 | |
| 
 | |
class StoreKeyFetcher(KeyFetcher):
    """KeyFetcher impl which fetches keys from our data store"""

    def __init__(self, hs):
        self.store = hs.get_datastore()

    async def get_keys(self, keys_to_fetch):
        """see KeyFetcher.get_keys

        Flattens the request into (server_name, key_id) pairs, asks the datastore
        for all of them in one call, then regroups the flat answer by server.
        The ``min_valid_ts`` values in ``keys_to_fetch`` are not used here.
        """
        wanted_pairs = (
            (origin, key_id)
            for origin, ids_for_origin in keys_to_fetch.items()
            for key_id in ids_for_origin
        )

        stored = await self.store.get_server_verify_keys(wanted_pairs)

        grouped = {}
        for (origin, key_id), fetch_result in stored.items():
            if origin not in grouped:
                grouped[origin] = {}
            grouped[origin][key_id] = fetch_result
        return grouped
 | |
| 
 | |
| 
 | |
class BaseV2KeyFetcher(object):
    """Shared logic for key fetchers which consume v2 "Server Keys" responses."""

    def __init__(self, hs):
        self.store = hs.get_datastore()
        self.config = hs.get_config()

    @staticmethod
    def _decode_verify_key(key_id, key_data):
        """Decode the unpadded-base64 "key" field of one verify-key JSON entry."""
        return decode_verify_key_bytes(key_id, decode_base64(key_data["key"]))

    async def process_v2_response(self, from_server, response_json, time_added_ms):
        """Parse a 'Server Keys' structure from the result of a /key request

        This is used to parse either the entirety of the response from
        GET /_matrix/key/v2/server, or a single entry from the list returned by
        POST /_matrix/key/v2/query.

        Checks that each signature in the response that claims to come from the origin
        server is valid, and that there is at least one such signature.

        Stores the json in server_keys_json so that it can be used for future responses
        to /_matrix/key/v2/query.

        Args:
            from_server (str): the name of the server producing this result: either
                the origin server for a /_matrix/key/v2/server request, or the notary
                for a /_matrix/key/v2/query.

            response_json (dict): the json-decoded Server Keys response object

            time_added_ms (int): the timestamp to record in server_keys_json

        Returns:
            dict[str, FetchKeyResult]: map from key_id to result object

        Raises:
            KeyLookupError: if the response is not signed by the origin server with
                one of the keys listed in its verify_keys, or if that signature
                does not verify.
        """
        ts_valid_until_ms = response_json["valid_until_ts"]

        # start by extracting the keys from the response, since they may be required
        # to validate the signature on the response.
        verify_keys = {}
        for key_id, key_data in response_json["verify_keys"].items():
            if is_signing_algorithm_supported(key_id):
                verify_keys[key_id] = FetchKeyResult(
                    verify_key=self._decode_verify_key(key_id, key_data),
                    valid_until_ts=ts_valid_until_ms,
                )

        server_name = response_json["server_name"]
        verified = False
        # An absent "signatures" field means the response is unsigned: report that
        # via the KeyLookupError below rather than crashing with a KeyError.
        for key_id in response_json.get("signatures", {}).get(server_name, {}):
            key = verify_keys.get(key_id)
            if not key:
                # the key may not be present in verify_keys if:
                #  * we got the key from the notary server, and:
                #  * the key belongs to the notary server, and:
                #  * the notary server is using a different key to sign notary
                #    responses.
                continue

            try:
                verify_signed_json(response_json, server_name, key.verify_key)
            except SignatureVerifyException as e:
                # Surface this as a KeyLookupError so that callers which process
                # several responses (eg, notary results) can skip just this one.
                raise KeyLookupError(
                    "Key response for %s has an invalid signature from %s: %s"
                    % (server_name, key_id, e)
                ) from e
            verified = True
            break

        if not verified:
            raise KeyLookupError(
                "Key response for %s is not signed by the origin server"
                % (server_name,)
            )

        # Don't fail the whole response if a server omits old_verify_keys;
        # treat it as having no old keys.
        for key_id, key_data in response_json.get("old_verify_keys", {}).items():
            if is_signing_algorithm_supported(key_id):
                verify_keys[key_id] = FetchKeyResult(
                    verify_key=self._decode_verify_key(key_id, key_data),
                    valid_until_ts=key_data["expired_ts"],
                )

        key_json_bytes = encode_canonical_json(response_json)

        await make_deferred_yieldable(
            defer.gatherResults(
                [
                    run_in_background(
                        self.store.store_server_keys_json,
                        server_name=server_name,
                        key_id=key_id,
                        from_server=from_server,
                        ts_now_ms=time_added_ms,
                        ts_expires_ms=ts_valid_until_ms,
                        key_json_bytes=key_json_bytes,
                    )
                    for key_id in verify_keys
                ],
                consumeErrors=True,
            ).addErrback(unwrapFirstError)
        )

        return verify_keys
 | |
| 
 | |
| 
 | |
class PerspectivesKeyFetcher(BaseV2KeyFetcher):
    """KeyFetcher impl which fetches keys from the "perspectives" servers"""

    def __init__(self, hs):
        super(PerspectivesKeyFetcher, self).__init__(hs)
        self.clock = hs.get_clock()
        self.client = hs.get_http_client()
        self.key_servers = self.config.key_servers

    async def get_keys(self, keys_to_fetch):
        """see KeyFetcher.get_keys

        Queries every configured notary server in parallel. A failure from an
        individual notary is logged and treated as "no keys" from that notary.
        """

        async def get_key(key_server):
            try:
                result = await self.get_server_verify_key_v2_indirect(
                    keys_to_fetch, key_server
                )
                return result
            except KeyLookupError as e:
                logger.warning(
                    "Key lookup failed from %r: %s", key_server.server_name, e
                )
            except Exception as e:
                logger.exception(
                    "Unable to get key from %r: %s %s",
                    key_server.server_name,
                    type(e).__name__,
                    str(e),
                )

            return {}

        results = await make_deferred_yieldable(
            defer.gatherResults(
                [run_in_background(get_key, server) for server in self.key_servers],
                consumeErrors=True,
            ).addErrback(unwrapFirstError)
        )

        # Merge the per-notary results. When several notaries return the same key,
        # keep the one which stays valid longest, so that as many
        # minimum_valid_until_ts requirements as possible can be met (a plain
        # dict.update would let whichever notary is configured last win, even with
        # a shorter validity). Ties go to the later result, as before.
        union_of_keys = {}
        for result in results:
            for server_name, keys in result.items():
                merged = union_of_keys.setdefault(server_name, {})
                for key_id, key in keys.items():
                    existing = merged.get(key_id)
                    if (
                        existing is None
                        or key is None
                        or key.valid_until_ts >= existing.valid_until_ts
                    ):
                        merged[key_id] = key if key is not None else existing

        return union_of_keys

    async def get_server_verify_key_v2_indirect(self, keys_to_fetch, key_server):
        """
        Args:
            keys_to_fetch (dict[str, dict[str, int]]):
                the keys to be fetched. server_name -> key_id -> min_valid_ts

            key_server (synapse.config.key.TrustedKeyServer): notary server to query for
                the keys

        Returns:
            dict[str, dict[str, synapse.storage.keys.FetchKeyResult]]: map
                from server_name -> key_id -> FetchKeyResult

        Raises:
            KeyLookupError if there was an error processing the entire response from
                the server (including a response which is not shaped like
                {"server_keys": [...]})
        """
        perspective_name = key_server.server_name
        logger.info(
            "Requesting keys %s from notary server %s",
            keys_to_fetch.items(),
            perspective_name,
        )

        try:
            query_response = await self.client.post_json(
                destination=perspective_name,
                path="/_matrix/key/v2/query",
                data={
                    "server_keys": {
                        server_name: {
                            key_id: {"minimum_valid_until_ts": min_valid_ts}
                            for key_id, min_valid_ts in server_keys.items()
                        }
                        for server_name, server_keys in keys_to_fetch.items()
                    }
                },
            )
        except (NotRetryingDestination, RequestSendFailed) as e:
            # these both have str() representations which we can't really improve upon
            raise KeyLookupError(str(e)) from e
        except HttpResponseException as e:
            raise KeyLookupError("Remote server returned an error: %s" % (e,)) from e

        # Validate the overall shape up front, so that a malformed response gives a
        # KeyLookupError rather than a KeyError/AttributeError.
        server_keys_list = (
            query_response.get("server_keys")
            if isinstance(query_response, dict)
            else None
        )
        if not isinstance(server_keys_list, list):
            raise KeyLookupError(
                "Malformed response from key notary server %s: invalid server_keys"
                % (perspective_name,)
            )

        keys = {}
        added_keys = []

        time_now_ms = self.clock.time_msec()

        for response in server_keys_list:
            # do this first, so that we can give useful errors thereafter
            server_name = (
                response.get("server_name") if isinstance(response, dict) else None
            )
            if not isinstance(server_name, str):
                raise KeyLookupError(
                    "Malformed response from key notary server %s: invalid server_name"
                    % (perspective_name,)
                )

            try:
                self._validate_perspectives_response(key_server, response)

                processed_response = await self.process_v2_response(
                    perspective_name, response, time_added_ms=time_now_ms
                )
            except KeyLookupError as e:
                logger.warning(
                    "Error processing response from key notary server %s for origin "
                    "server %s: %s",
                    perspective_name,
                    server_name,
                    e,
                )
                # we continue to process the rest of the response
                continue

            added_keys.extend(
                (server_name, key_id, key) for key_id, key in processed_response.items()
            )
            keys.setdefault(server_name, {}).update(processed_response)

        await self.store.store_server_verify_keys(
            perspective_name, time_now_ms, added_keys
        )

        return keys

    def _validate_perspectives_response(self, key_server, response):
        """Optionally check the signature on the result of a /key/query request

        If the notary server has no configured verify keys, no check is made.
        Otherwise every signature by the notary made with a known key is checked.

        Args:
            key_server (synapse.config.key.TrustedKeyServer): the notary server that
                produced this result

            response (dict): the json-decoded Server Keys response object

        Raises:
            KeyLookupError: if the response carries no signature from the notary
                made with a known key.
        """
        perspective_name = key_server.server_name
        perspective_keys = key_server.verify_keys

        if perspective_keys is None:
            # signature checking is disabled on this server
            return

        if (
            "signatures" not in response
            or perspective_name not in response["signatures"]
        ):
            raise KeyLookupError("Response not signed by the notary server")

        verified = False
        for key_id in response["signatures"][perspective_name]:
            if key_id in perspective_keys:
                verify_signed_json(response, perspective_name, perspective_keys[key_id])
                verified = True

        if not verified:
            raise KeyLookupError(
                "Response not signed with a known key: signed with: %r, known keys: %r"
                % (
                    list(response["signatures"][perspective_name].keys()),
                    list(perspective_keys.keys()),
                )
            )
 | |
| 
 | |
| 
 | |
class ServerKeyFetcher(BaseV2KeyFetcher):
    """KeyFetcher impl which fetches keys from the origin servers"""

    def __init__(self, hs):
        super(ServerKeyFetcher, self).__init__(hs)
        self.clock = hs.get_clock()
        self.client = hs.get_http_client()

    async def get_keys(self, keys_to_fetch):
        """Fetch the requested keys directly from each origin server.

        A failure for one server is logged (a warning for ``KeyLookupError``,
        a full traceback for anything else) and that server is left out of the
        result; it does not abort the lookups for the other servers.

        Args:
            keys_to_fetch (dict[str, iterable[str]]):
                the keys to be fetched. server_name -> key_ids

        Returns:
            dict[str, dict[str, synapse.storage.keys.FetchKeyResult|None]]:
                map from server_name -> key_id -> FetchKeyResult
        """

        results = {}

        async def get_key(key_to_fetch_item):
            server_name, key_ids = key_to_fetch_item
            try:
                keys = await self.get_server_verify_key_v2_direct(server_name, key_ids)
                results[server_name] = keys
            except KeyLookupError as e:
                logger.warning(
                    "Error looking up keys %s from %s: %s", key_ids, server_name, e
                )
            except Exception:
                logger.exception("Error getting keys %s from %s", key_ids, server_name)

        await yieldable_gather_results(get_key, keys_to_fetch.items())
        return results

    async def get_server_verify_key_v2_direct(self, server_name, key_ids):
        """Fetch the given keys from ``server_name``'s own key endpoint.

        One request is made per requested key ID. A key ID that an earlier
        request in the same call already returned is skipped. Every batch
        of keys returned is persisted via ``store_server_verify_keys``
        before moving on.

        Args:
            server_name (str): the origin server to query
            key_ids (iterable[str]): the key IDs wanted

        Returns:
            dict[str, FetchKeyResult]: map from key ID to lookup result

        Raises:
            KeyLookupError if there was a problem making the lookup, including
            a response that is not a JSON object or is for a different server.
        """
        keys = {}  # type: dict[str, FetchKeyResult]

        for requested_key_id in key_ids:
            # we may have found this key as a side-effect of asking for another.
            if requested_key_id in keys:
                continue

            time_now_ms = self.clock.time_msec()
            try:
                response = await self.client.get_json(
                    destination=server_name,
                    # The key ID must stay a single path segment whatever it
                    # contains, so escape "/" too (quote() leaves it unescaped
                    # by default). For ordinary key IDs this is a no-op.
                    path="/_matrix/key/v2/server/"
                    + urllib.parse.quote(requested_key_id, safe=""),
                    ignore_backoff=True,
                    # we only give the remote server 10s to respond. It should be an
                    # easy request to handle, so if it doesn't reply within 10s, it's
                    # probably not going to.
                    #
                    # Furthermore, when we are acting as a notary server, we cannot
                    # wait all day for all of the origin servers, as the requesting
                    # server will otherwise time out before we can respond.
                    #
                    # (Note that get_json may make 4 attempts, so this can still take
                    # almost 45 seconds to fetch the headers, plus up to another 60s to
                    # read the response).
                    timeout=10000,
                )
            except (NotRetryingDestination, RequestSendFailed) as e:
                # these both have str() representations which we can't really improve
                # upon
                raise KeyLookupError(str(e)) from e
            except HttpResponseException as e:
                raise KeyLookupError("Remote server returned an error: %s" % (e,)) from e

            # The body comes from a remote server, so a malformed one must be
            # reported as a lookup failure, not escape as a KeyError/TypeError.
            if not isinstance(response, dict):
                raise KeyLookupError(
                    "Response from server %r was not a JSON object" % (server_name,)
                )

            response_server_name = response.get("server_name")
            if response_server_name != server_name:
                raise KeyLookupError(
                    "Expected a response for server %r not %r"
                    % (server_name, response_server_name)
                )

            response_keys = await self.process_v2_response(
                from_server=server_name,
                response_json=response,
                time_added_ms=time_now_ms,
            )
            await self.store.store_server_verify_keys(
                server_name,
                time_now_ms,
                ((server_name, key_id, key) for key_id, key in response_keys.items()),
            )
            keys.update(response_keys)

        return keys
 | |
| 
 | |
| 
 | |
async def _handle_key_deferred(verify_request) -> None:
    """Wait for the verify key of a request, then check the request's signature.

    The ``key_ready`` deferred is awaited inside ``PreserveLoggingContext``.
    It is expected to resolve to a 3-tuple whose last item is the verify key.
    That key is used to check the signature made by ``server_name`` on
    ``json_object``.

    Args:
        verify_request (VerifyJsonRequest):

    Raises:
        SynapseError: with status 401 and ``Codes.UNAUTHORIZED`` if the
            signature does not verify.
    """
    origin = verify_request.server_name

    with PreserveLoggingContext():
        _, _unused_key_id, key = await verify_request.key_ready

    signed_object = verify_request.json_object

    try:
        verify_signed_json(signed_object, origin, key)
    except SignatureVerifyException as exc:
        reason = str(exc)
        logger.debug(
            "Error verifying signature for %s:%s:%s with key %s: %s",
            origin,
            key.alg,
            key.version,
            encode_verify_key_base64(key),
            reason,
        )
        message = "Invalid signature for server %s with key %s:%s: %s" % (
            origin,
            key.alg,
            key.version,
            reason,
        )
        raise SynapseError(401, message, Codes.UNAUTHORIZED)
 |