903 lines
		
	
	
		
			33 KiB
		
	
	
	
		
			Python
		
	
	
			
		
		
	
	
			903 lines
		
	
	
		
			33 KiB
		
	
	
	
		
			Python
		
	
	
# Copyright 2014-2021 The Matrix.org Foundation C.I.C.
 | 
						|
#
 | 
						|
# Licensed under the Apache License, Version 2.0 (the "License");
 | 
						|
# you may not use this file except in compliance with the License.
 | 
						|
# You may obtain a copy of the License at
 | 
						|
#
 | 
						|
#     http://www.apache.org/licenses/LICENSE-2.0
 | 
						|
#
 | 
						|
# Unless required by applicable law or agreed to in writing, software
 | 
						|
# distributed under the License is distributed on an "AS IS" BASIS,
 | 
						|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
						|
# See the License for the specific language governing permissions and
 | 
						|
# limitations under the License.
 | 
						|
 | 
						|
import abc
 | 
						|
import logging
 | 
						|
from typing import TYPE_CHECKING, Callable, Dict, Iterable, List, Optional, Tuple
 | 
						|
 | 
						|
import attr
 | 
						|
from signedjson.key import (
 | 
						|
    decode_verify_key_bytes,
 | 
						|
    encode_verify_key_base64,
 | 
						|
    get_verify_key,
 | 
						|
    is_signing_algorithm_supported,
 | 
						|
)
 | 
						|
from signedjson.sign import SignatureVerifyException, signature_ids, verify_signed_json
 | 
						|
from signedjson.types import VerifyKey
 | 
						|
from unpaddedbase64 import decode_base64
 | 
						|
 | 
						|
from twisted.internet import defer
 | 
						|
 | 
						|
from synapse.api.errors import (
 | 
						|
    Codes,
 | 
						|
    HttpResponseException,
 | 
						|
    RequestSendFailed,
 | 
						|
    SynapseError,
 | 
						|
)
 | 
						|
from synapse.config.key import TrustedKeyServer
 | 
						|
from synapse.events import EventBase
 | 
						|
from synapse.events.utils import prune_event_dict
 | 
						|
from synapse.logging.context import make_deferred_yieldable, run_in_background
 | 
						|
from synapse.storage.keys import FetchKeyResult
 | 
						|
from synapse.types import JsonDict
 | 
						|
from synapse.util import unwrapFirstError
 | 
						|
from synapse.util.async_helpers import yieldable_gather_results
 | 
						|
from synapse.util.batching_queue import BatchingQueue
 | 
						|
from synapse.util.retryutils import NotRetryingDestination
 | 
						|
 | 
						|
if TYPE_CHECKING:
 | 
						|
    from synapse.server import HomeServer
 | 
						|
 | 
						|
logger = logging.getLogger(__name__)
 | 
						|
 | 
						|
 | 
						|
@attr.s(slots=True, frozen=True, cmp=False, auto_attribs=True)
 | 
						|
class VerifyJsonRequest:
 | 
						|
    """
 | 
						|
    A request to verify a JSON object.
 | 
						|
 | 
						|
    Attributes:
 | 
						|
        server_name: The name of the server to verify against.
 | 
						|
 | 
						|
        get_json_object: A callback to fetch the JSON object to verify.
 | 
						|
            A callback is used to allow deferring the creation of the JSON
 | 
						|
            object to verify until needed, e.g. for events we can defer
 | 
						|
            creating the redacted copy. This reduces the memory usage when
 | 
						|
            there are large numbers of in flight requests.
 | 
						|
 | 
						|
        minimum_valid_until_ts: time at which we require the signing key to
 | 
						|
            be valid. (0 implies we don't care)
 | 
						|
 | 
						|
        key_ids: The set of key_ids to that could be used to verify the JSON object
 | 
						|
    """
 | 
						|
 | 
						|
    server_name: str
 | 
						|
    get_json_object: Callable[[], JsonDict]
 | 
						|
    minimum_valid_until_ts: int
 | 
						|
    key_ids: List[str]
 | 
						|
 | 
						|
    @staticmethod
 | 
						|
    def from_json_object(
 | 
						|
        server_name: str,
 | 
						|
        json_object: JsonDict,
 | 
						|
        minimum_valid_until_ms: int,
 | 
						|
    ) -> "VerifyJsonRequest":
 | 
						|
        """Create a VerifyJsonRequest to verify all signatures on a signed JSON
 | 
						|
        object for the given server.
 | 
						|
        """
 | 
						|
        key_ids = signature_ids(json_object, server_name)
 | 
						|
        return VerifyJsonRequest(
 | 
						|
            server_name,
 | 
						|
            lambda: json_object,
 | 
						|
            minimum_valid_until_ms,
 | 
						|
            key_ids=key_ids,
 | 
						|
        )
 | 
						|
 | 
						|
    @staticmethod
 | 
						|
    def from_event(
 | 
						|
        server_name: str,
 | 
						|
        event: EventBase,
 | 
						|
        minimum_valid_until_ms: int,
 | 
						|
    ) -> "VerifyJsonRequest":
 | 
						|
        """Create a VerifyJsonRequest to verify all signatures on an event
 | 
						|
        object for the given server.
 | 
						|
        """
 | 
						|
        key_ids = list(event.signatures.get(server_name, []))
 | 
						|
        return VerifyJsonRequest(
 | 
						|
            server_name,
 | 
						|
            # We defer creating the redacted json object, as it uses a lot more
 | 
						|
            # memory than the Event object itself.
 | 
						|
            lambda: prune_event_dict(event.room_version, event.get_pdu_json()),
 | 
						|
            minimum_valid_until_ms,
 | 
						|
            key_ids=key_ids,
 | 
						|
        )
 | 
						|
 | 
						|
 | 
						|
class KeyLookupError(ValueError):
 | 
						|
    pass
 | 
						|
 | 
						|
 | 
						|
@attr.s(slots=True, frozen=True, auto_attribs=True)
 | 
						|
class _FetchKeyRequest:
 | 
						|
    """A request for keys for a given server.
 | 
						|
 | 
						|
    We will continue to try and fetch until we have all the keys listed under
 | 
						|
    `key_ids` (with an appropriate `valid_until_ts` property) or we run out of
 | 
						|
    places to fetch keys from.
 | 
						|
 | 
						|
    Attributes:
 | 
						|
        server_name: The name of the server that owns the keys.
 | 
						|
        minimum_valid_until_ts: The timestamp which the keys must be valid until.
 | 
						|
        key_ids: The IDs of the keys to attempt to fetch
 | 
						|
    """
 | 
						|
 | 
						|
    server_name: str
 | 
						|
    minimum_valid_until_ts: int
 | 
						|
    key_ids: List[str]
 | 
						|
 | 
						|
 | 
						|
class Keyring:
 | 
						|
    """Handles verifying signed JSON objects and fetching the keys needed to do
 | 
						|
    so.
 | 
						|
    """
 | 
						|
 | 
						|
    def __init__(
 | 
						|
        self, hs: "HomeServer", key_fetchers: "Optional[Iterable[KeyFetcher]]" = None
 | 
						|
    ):
 | 
						|
        if key_fetchers is None:
 | 
						|
            # Always fetch keys from the database.
 | 
						|
            mutable_key_fetchers: List[KeyFetcher] = [StoreKeyFetcher(hs)]
 | 
						|
            # Fetch keys from configured trusted key servers, if any exist.
 | 
						|
            key_servers = hs.config.key.key_servers
 | 
						|
            if key_servers:
 | 
						|
                mutable_key_fetchers.append(PerspectivesKeyFetcher(hs))
 | 
						|
            # Finally, fetch keys from the origin server directly.
 | 
						|
            mutable_key_fetchers.append(ServerKeyFetcher(hs))
 | 
						|
 | 
						|
            self._key_fetchers: Iterable[KeyFetcher] = tuple(mutable_key_fetchers)
 | 
						|
        else:
 | 
						|
            self._key_fetchers = key_fetchers
 | 
						|
 | 
						|
        self._fetch_keys_queue: BatchingQueue[
 | 
						|
            _FetchKeyRequest, Dict[str, Dict[str, FetchKeyResult]]
 | 
						|
        ] = BatchingQueue(
 | 
						|
            "keyring_server",
 | 
						|
            clock=hs.get_clock(),
 | 
						|
            # The method called to fetch each key
 | 
						|
            process_batch_callback=self._inner_fetch_key_requests,
 | 
						|
        )
 | 
						|
 | 
						|
        self._is_mine_server_name = hs.is_mine_server_name
 | 
						|
 | 
						|
        # build a FetchKeyResult for each of our own keys, to shortcircuit the
 | 
						|
        # fetcher.
 | 
						|
        self._local_verify_keys: Dict[str, FetchKeyResult] = {}
 | 
						|
        for key_id, key in hs.config.key.old_signing_keys.items():
 | 
						|
            self._local_verify_keys[key_id] = FetchKeyResult(
 | 
						|
                verify_key=key, valid_until_ts=key.expired
 | 
						|
            )
 | 
						|
 | 
						|
        vk = get_verify_key(hs.signing_key)
 | 
						|
        self._local_verify_keys[f"{vk.alg}:{vk.version}"] = FetchKeyResult(
 | 
						|
            verify_key=vk,
 | 
						|
            valid_until_ts=2**63,  # fake future timestamp
 | 
						|
        )
 | 
						|
 | 
						|
    async def verify_json_for_server(
 | 
						|
        self,
 | 
						|
        server_name: str,
 | 
						|
        json_object: JsonDict,
 | 
						|
        validity_time: int,
 | 
						|
    ) -> None:
 | 
						|
        """Verify that a JSON object has been signed by a given server
 | 
						|
 | 
						|
        Completes if the the object was correctly signed, otherwise raises.
 | 
						|
 | 
						|
        Args:
 | 
						|
            server_name: name of the server which must have signed this object
 | 
						|
 | 
						|
            json_object: object to be checked
 | 
						|
 | 
						|
            validity_time: timestamp at which we require the signing key to
 | 
						|
                be valid. (0 implies we don't care)
 | 
						|
        """
 | 
						|
 | 
						|
        request = VerifyJsonRequest.from_json_object(
 | 
						|
            server_name,
 | 
						|
            json_object,
 | 
						|
            validity_time,
 | 
						|
        )
 | 
						|
        return await self.process_request(request)
 | 
						|
 | 
						|
    def verify_json_objects_for_server(
 | 
						|
        self, server_and_json: Iterable[Tuple[str, dict, int]]
 | 
						|
    ) -> List["defer.Deferred[None]"]:
 | 
						|
        """Bulk verifies signatures of json objects, bulk fetching keys as
 | 
						|
        necessary.
 | 
						|
 | 
						|
        Args:
 | 
						|
            server_and_json:
 | 
						|
                Iterable of (server_name, json_object, validity_time)
 | 
						|
                tuples.
 | 
						|
 | 
						|
                validity_time is a timestamp at which the signing key must be
 | 
						|
                valid.
 | 
						|
 | 
						|
        Returns:
 | 
						|
            For each input triplet, a deferred indicating success or failure to
 | 
						|
            verify each json object's signature for the given server_name. The
 | 
						|
            deferreds run their callbacks in the sentinel logcontext.
 | 
						|
        """
 | 
						|
        return [
 | 
						|
            run_in_background(
 | 
						|
                self.process_request,
 | 
						|
                VerifyJsonRequest.from_json_object(
 | 
						|
                    server_name,
 | 
						|
                    json_object,
 | 
						|
                    validity_time,
 | 
						|
                ),
 | 
						|
            )
 | 
						|
            for server_name, json_object, validity_time in server_and_json
 | 
						|
        ]
 | 
						|
 | 
						|
    async def verify_event_for_server(
 | 
						|
        self,
 | 
						|
        server_name: str,
 | 
						|
        event: EventBase,
 | 
						|
        validity_time: int,
 | 
						|
    ) -> None:
 | 
						|
        await self.process_request(
 | 
						|
            VerifyJsonRequest.from_event(
 | 
						|
                server_name,
 | 
						|
                event,
 | 
						|
                validity_time,
 | 
						|
            )
 | 
						|
        )
 | 
						|
 | 
						|
    async def process_request(self, verify_request: VerifyJsonRequest) -> None:
 | 
						|
        """Processes the `VerifyJsonRequest`. Raises if the object is not signed
 | 
						|
        by the server, the signatures don't match or we failed to fetch the
 | 
						|
        necessary keys.
 | 
						|
        """
 | 
						|
 | 
						|
        if not verify_request.key_ids:
 | 
						|
            raise SynapseError(
 | 
						|
                400,
 | 
						|
                f"Not signed by {verify_request.server_name}",
 | 
						|
                Codes.UNAUTHORIZED,
 | 
						|
            )
 | 
						|
 | 
						|
        found_keys: Dict[str, FetchKeyResult] = {}
 | 
						|
 | 
						|
        # If we are the originating server, short-circuit the key-fetch for any keys
 | 
						|
        # we already have
 | 
						|
        if self._is_mine_server_name(verify_request.server_name):
 | 
						|
            for key_id in verify_request.key_ids:
 | 
						|
                if key_id in self._local_verify_keys:
 | 
						|
                    found_keys[key_id] = self._local_verify_keys[key_id]
 | 
						|
 | 
						|
        key_ids_to_find = set(verify_request.key_ids) - found_keys.keys()
 | 
						|
        if key_ids_to_find:
 | 
						|
            # Add the keys we need to verify to the queue for retrieval. We queue
 | 
						|
            # up requests for the same server so we don't end up with many in flight
 | 
						|
            # requests for the same keys.
 | 
						|
            key_request = _FetchKeyRequest(
 | 
						|
                server_name=verify_request.server_name,
 | 
						|
                minimum_valid_until_ts=verify_request.minimum_valid_until_ts,
 | 
						|
                key_ids=list(key_ids_to_find),
 | 
						|
            )
 | 
						|
            found_keys_by_server = await self._fetch_keys_queue.add_to_queue(
 | 
						|
                key_request, key=verify_request.server_name
 | 
						|
            )
 | 
						|
 | 
						|
            # Since we batch up requests the returned set of keys may contain keys
 | 
						|
            # from other servers, so we pull out only the ones we care about.
 | 
						|
            found_keys.update(found_keys_by_server.get(verify_request.server_name, {}))
 | 
						|
 | 
						|
        # Verify each signature we got valid keys for, raising if we can't
 | 
						|
        # verify any of them.
 | 
						|
        verified = False
 | 
						|
        for key_id in verify_request.key_ids:
 | 
						|
            key_result = found_keys.get(key_id)
 | 
						|
            if not key_result:
 | 
						|
                continue
 | 
						|
 | 
						|
            if key_result.valid_until_ts < verify_request.minimum_valid_until_ts:
 | 
						|
                continue
 | 
						|
 | 
						|
            await self._process_json(key_result.verify_key, verify_request)
 | 
						|
            verified = True
 | 
						|
 | 
						|
        if not verified:
 | 
						|
            raise SynapseError(
 | 
						|
                401,
 | 
						|
                f"Failed to find any key to satisfy: {key_request}",
 | 
						|
                Codes.UNAUTHORIZED,
 | 
						|
            )
 | 
						|
 | 
						|
    async def _process_json(
 | 
						|
        self, verify_key: VerifyKey, verify_request: VerifyJsonRequest
 | 
						|
    ) -> None:
 | 
						|
        """Processes the `VerifyJsonRequest`. Raises if the signature can't be
 | 
						|
        verified.
 | 
						|
        """
 | 
						|
        try:
 | 
						|
            verify_signed_json(
 | 
						|
                verify_request.get_json_object(),
 | 
						|
                verify_request.server_name,
 | 
						|
                verify_key,
 | 
						|
            )
 | 
						|
        except SignatureVerifyException as e:
 | 
						|
            logger.debug(
 | 
						|
                "Error verifying signature for %s:%s:%s with key %s: %s",
 | 
						|
                verify_request.server_name,
 | 
						|
                verify_key.alg,
 | 
						|
                verify_key.version,
 | 
						|
                encode_verify_key_base64(verify_key),
 | 
						|
                str(e),
 | 
						|
            )
 | 
						|
            raise SynapseError(
 | 
						|
                401,
 | 
						|
                "Invalid signature for server %s with key %s:%s: %s"
 | 
						|
                % (
 | 
						|
                    verify_request.server_name,
 | 
						|
                    verify_key.alg,
 | 
						|
                    verify_key.version,
 | 
						|
                    str(e),
 | 
						|
                ),
 | 
						|
                Codes.UNAUTHORIZED,
 | 
						|
            )
 | 
						|
 | 
						|
    async def _inner_fetch_key_requests(
 | 
						|
        self, requests: List[_FetchKeyRequest]
 | 
						|
    ) -> Dict[str, Dict[str, FetchKeyResult]]:
 | 
						|
        """Processing function for the queue of `_FetchKeyRequest`.
 | 
						|
 | 
						|
        Takes a list of key fetch requests, de-duplicates them and then carries out
 | 
						|
        each request by invoking self._inner_fetch_key_request.
 | 
						|
 | 
						|
        Args:
 | 
						|
            requests: A list of requests for homeserver verify keys.
 | 
						|
 | 
						|
        Returns:
 | 
						|
            {server name: {key id: fetch key result}}
 | 
						|
        """
 | 
						|
 | 
						|
        logger.debug("Starting fetch for %s", requests)
 | 
						|
 | 
						|
        # First we need to deduplicate requests for the same key. We do this by
 | 
						|
        # taking the *maximum* requested `minimum_valid_until_ts` for each pair
 | 
						|
        # of server name/key ID.
 | 
						|
        server_to_key_to_ts: Dict[str, Dict[str, int]] = {}
 | 
						|
        for request in requests:
 | 
						|
            by_server = server_to_key_to_ts.setdefault(request.server_name, {})
 | 
						|
            for key_id in request.key_ids:
 | 
						|
                existing_ts = by_server.get(key_id, 0)
 | 
						|
                by_server[key_id] = max(request.minimum_valid_until_ts, existing_ts)
 | 
						|
 | 
						|
        deduped_requests = [
 | 
						|
            _FetchKeyRequest(server_name, minimum_valid_ts, [key_id])
 | 
						|
            for server_name, by_server in server_to_key_to_ts.items()
 | 
						|
            for key_id, minimum_valid_ts in by_server.items()
 | 
						|
        ]
 | 
						|
 | 
						|
        logger.debug("Deduplicated key requests to %s", deduped_requests)
 | 
						|
 | 
						|
        # For each key we call `_inner_verify_request` which will handle
 | 
						|
        # fetching each key. Note these shouldn't throw if we fail to contact
 | 
						|
        # other servers etc.
 | 
						|
        results_per_request = await yieldable_gather_results(
 | 
						|
            self._inner_fetch_key_request,
 | 
						|
            deduped_requests,
 | 
						|
        )
 | 
						|
 | 
						|
        # We now convert the returned list of results into a map from server
 | 
						|
        # name to key ID to FetchKeyResult, to return.
 | 
						|
        to_return: Dict[str, Dict[str, FetchKeyResult]] = {}
 | 
						|
        for request, results in zip(deduped_requests, results_per_request):
 | 
						|
            to_return_by_server = to_return.setdefault(request.server_name, {})
 | 
						|
            for key_id, key_result in results.items():
 | 
						|
                existing = to_return_by_server.get(key_id)
 | 
						|
                if not existing or existing.valid_until_ts < key_result.valid_until_ts:
 | 
						|
                    to_return_by_server[key_id] = key_result
 | 
						|
 | 
						|
        return to_return
 | 
						|
 | 
						|
    async def _inner_fetch_key_request(
 | 
						|
        self, verify_request: _FetchKeyRequest
 | 
						|
    ) -> Dict[str, FetchKeyResult]:
 | 
						|
        """Attempt to fetch the given key by calling each key fetcher one by one.
 | 
						|
 | 
						|
        If a key is found, check whether its `valid_until_ts` attribute satisfies the
 | 
						|
        `minimum_valid_until_ts` attribute of the `verify_request`. If it does, we
 | 
						|
        refrain from asking subsequent fetchers for that key.
 | 
						|
 | 
						|
        Even if the above check fails, we still return the found key - the caller may
 | 
						|
        still find the invalid key result useful. In this case, we continue to ask
 | 
						|
        subsequent fetchers for the invalid key, in case they return a valid result
 | 
						|
        for it. This can happen when fetching a stale key result from the database,
 | 
						|
        before querying the origin server for an up-to-date result.
 | 
						|
 | 
						|
        Args:
 | 
						|
            verify_request: The request for a verify key. Can include multiple key IDs.
 | 
						|
 | 
						|
        Returns:
 | 
						|
            A map of {key_id: the key fetch result}.
 | 
						|
        """
 | 
						|
        logger.debug("Starting fetch for %s", verify_request)
 | 
						|
 | 
						|
        found_keys: Dict[str, FetchKeyResult] = {}
 | 
						|
        missing_key_ids = set(verify_request.key_ids)
 | 
						|
 | 
						|
        for fetcher in self._key_fetchers:
 | 
						|
            if not missing_key_ids:
 | 
						|
                break
 | 
						|
 | 
						|
            logger.debug("Getting keys from %s for %s", fetcher, verify_request)
 | 
						|
            keys = await fetcher.get_keys(
 | 
						|
                verify_request.server_name,
 | 
						|
                list(missing_key_ids),
 | 
						|
                verify_request.minimum_valid_until_ts,
 | 
						|
            )
 | 
						|
 | 
						|
            for key_id, key in keys.items():
 | 
						|
                if not key:
 | 
						|
                    continue
 | 
						|
 | 
						|
                # If we already have a result for the given key ID, we keep the
 | 
						|
                # one with the highest `valid_until_ts`.
 | 
						|
                existing_key = found_keys.get(key_id)
 | 
						|
                if existing_key and existing_key.valid_until_ts > key.valid_until_ts:
 | 
						|
                    continue
 | 
						|
 | 
						|
                # Check if this key's expiry timestamp is valid for the verify request.
 | 
						|
                if key.valid_until_ts >= verify_request.minimum_valid_until_ts:
 | 
						|
                    # Stop looking for this key from subsequent fetchers.
 | 
						|
                    missing_key_ids.discard(key_id)
 | 
						|
 | 
						|
                # We always store the returned key even if it doesn't meet the
 | 
						|
                # `minimum_valid_until_ts` requirement, as some verification
 | 
						|
                # requests may still be able to be satisfied by it.
 | 
						|
                found_keys[key_id] = key
 | 
						|
 | 
						|
        return found_keys
 | 
						|
 | 
						|
 | 
						|
class KeyFetcher(metaclass=abc.ABCMeta):
 | 
						|
    def __init__(self, hs: "HomeServer"):
 | 
						|
        self._queue = BatchingQueue(
 | 
						|
            self.__class__.__name__, hs.get_clock(), self._fetch_keys
 | 
						|
        )
 | 
						|
 | 
						|
    async def get_keys(
 | 
						|
        self, server_name: str, key_ids: List[str], minimum_valid_until_ts: int
 | 
						|
    ) -> Dict[str, FetchKeyResult]:
 | 
						|
        results = await self._queue.add_to_queue(
 | 
						|
            _FetchKeyRequest(
 | 
						|
                server_name=server_name,
 | 
						|
                key_ids=key_ids,
 | 
						|
                minimum_valid_until_ts=minimum_valid_until_ts,
 | 
						|
            )
 | 
						|
        )
 | 
						|
        return results.get(server_name, {})
 | 
						|
 | 
						|
    @abc.abstractmethod
 | 
						|
    async def _fetch_keys(
 | 
						|
        self, keys_to_fetch: List[_FetchKeyRequest]
 | 
						|
    ) -> Dict[str, Dict[str, FetchKeyResult]]:
 | 
						|
        pass
 | 
						|
 | 
						|
 | 
						|
class StoreKeyFetcher(KeyFetcher):
 | 
						|
    """KeyFetcher impl which fetches keys from our data store"""
 | 
						|
 | 
						|
    def __init__(self, hs: "HomeServer"):
 | 
						|
        super().__init__(hs)
 | 
						|
 | 
						|
        self.store = hs.get_datastores().main
 | 
						|
 | 
						|
    async def _fetch_keys(
 | 
						|
        self, keys_to_fetch: List[_FetchKeyRequest]
 | 
						|
    ) -> Dict[str, Dict[str, FetchKeyResult]]:
 | 
						|
        key_ids_to_fetch = (
 | 
						|
            (queue_value.server_name, key_id)
 | 
						|
            for queue_value in keys_to_fetch
 | 
						|
            for key_id in queue_value.key_ids
 | 
						|
        )
 | 
						|
 | 
						|
        res = await self.store.get_server_keys_json(key_ids_to_fetch)
 | 
						|
        keys: Dict[str, Dict[str, FetchKeyResult]] = {}
 | 
						|
        for (server_name, key_id), key in res.items():
 | 
						|
            keys.setdefault(server_name, {})[key_id] = key
 | 
						|
        return keys
 | 
						|
 | 
						|
 | 
						|
class BaseV2KeyFetcher(KeyFetcher):
 | 
						|
    def __init__(self, hs: "HomeServer"):
 | 
						|
        super().__init__(hs)
 | 
						|
 | 
						|
        self.store = hs.get_datastores().main
 | 
						|
 | 
						|
    async def process_v2_response(
 | 
						|
        self, from_server: str, response_json: JsonDict, time_added_ms: int
 | 
						|
    ) -> Dict[str, FetchKeyResult]:
 | 
						|
        """Parse a 'Server Keys' structure from the result of a /key request
 | 
						|
 | 
						|
        This is used to parse either the entirety of the response from
 | 
						|
        GET /_matrix/key/v2/server, or a single entry from the list returned by
 | 
						|
        POST /_matrix/key/v2/query.
 | 
						|
 | 
						|
        Checks that each signature in the response that claims to come from the origin
 | 
						|
        server is valid, and that there is at least one such signature.
 | 
						|
 | 
						|
        Stores the json in server_keys_json so that it can be used for future responses
 | 
						|
        to /_matrix/key/v2/query.
 | 
						|
 | 
						|
        Args:
 | 
						|
            from_server: the name of the server producing this result: either
 | 
						|
                the origin server for a /_matrix/key/v2/server request, or the notary
 | 
						|
                for a /_matrix/key/v2/query.
 | 
						|
 | 
						|
            response_json: the json-decoded Server Keys response object
 | 
						|
 | 
						|
            time_added_ms: the timestamp to record in server_keys_json
 | 
						|
 | 
						|
        Returns:
 | 
						|
            Map from key_id to result object
 | 
						|
        """
 | 
						|
        ts_valid_until_ms = response_json["valid_until_ts"]
 | 
						|
 | 
						|
        # start by extracting the keys from the response, since they may be required
 | 
						|
        # to validate the signature on the response.
 | 
						|
        verify_keys = {}
 | 
						|
        for key_id, key_data in response_json["verify_keys"].items():
 | 
						|
            if is_signing_algorithm_supported(key_id):
 | 
						|
                key_base64 = key_data["key"]
 | 
						|
                key_bytes = decode_base64(key_base64)
 | 
						|
                verify_key = decode_verify_key_bytes(key_id, key_bytes)
 | 
						|
                verify_keys[key_id] = FetchKeyResult(
 | 
						|
                    verify_key=verify_key, valid_until_ts=ts_valid_until_ms
 | 
						|
                )
 | 
						|
 | 
						|
        server_name = response_json["server_name"]
 | 
						|
        verified = False
 | 
						|
        for key_id in response_json["signatures"].get(server_name, {}):
 | 
						|
            key = verify_keys.get(key_id)
 | 
						|
            if not key:
 | 
						|
                # the key may not be present in verify_keys if:
 | 
						|
                #  * we got the key from the notary server, and:
 | 
						|
                #  * the key belongs to the notary server, and:
 | 
						|
                #  * the notary server is using a different key to sign notary
 | 
						|
                #    responses.
 | 
						|
                continue
 | 
						|
 | 
						|
            verify_signed_json(response_json, server_name, key.verify_key)
 | 
						|
            verified = True
 | 
						|
            break
 | 
						|
 | 
						|
        if not verified:
 | 
						|
            raise KeyLookupError(
 | 
						|
                "Key response for %s is not signed by the origin server"
 | 
						|
                % (server_name,)
 | 
						|
            )
 | 
						|
 | 
						|
        for key_id, key_data in response_json["old_verify_keys"].items():
 | 
						|
            if is_signing_algorithm_supported(key_id):
 | 
						|
                key_base64 = key_data["key"]
 | 
						|
                key_bytes = decode_base64(key_base64)
 | 
						|
                verify_key = decode_verify_key_bytes(key_id, key_bytes)
 | 
						|
                verify_keys[key_id] = FetchKeyResult(
 | 
						|
                    verify_key=verify_key, valid_until_ts=key_data["expired_ts"]
 | 
						|
                )
 | 
						|
 | 
						|
        await self.store.store_server_keys_response(
 | 
						|
            server_name=server_name,
 | 
						|
            from_server=from_server,
 | 
						|
            ts_added_ms=time_added_ms,
 | 
						|
            verify_keys=verify_keys,
 | 
						|
            response_json=response_json,
 | 
						|
        )
 | 
						|
 | 
						|
        return verify_keys
 | 
						|
 | 
						|
 | 
						|
class PerspectivesKeyFetcher(BaseV2KeyFetcher):
 | 
						|
    """KeyFetcher impl which fetches keys from the "perspectives" servers"""
 | 
						|
 | 
						|
    def __init__(self, hs: "HomeServer"):
 | 
						|
        super().__init__(hs)
 | 
						|
        self.clock = hs.get_clock()
 | 
						|
        self.client = hs.get_federation_http_client()
 | 
						|
        self.key_servers = hs.config.key.key_servers
 | 
						|
 | 
						|
    async def _fetch_keys(
 | 
						|
        self, keys_to_fetch: List[_FetchKeyRequest]
 | 
						|
    ) -> Dict[str, Dict[str, FetchKeyResult]]:
 | 
						|
        """see KeyFetcher._fetch_keys"""
 | 
						|
 | 
						|
        async def get_key(key_server: TrustedKeyServer) -> Dict:
 | 
						|
            try:
 | 
						|
                return await self.get_server_verify_key_v2_indirect(
 | 
						|
                    keys_to_fetch, key_server
 | 
						|
                )
 | 
						|
            except KeyLookupError as e:
 | 
						|
                logger.warning(
 | 
						|
                    "Key lookup failed from %r: %s", key_server.server_name, e
 | 
						|
                )
 | 
						|
            except Exception as e:
 | 
						|
                logger.exception(
 | 
						|
                    "Unable to get key from %r: %s %s",
 | 
						|
                    key_server.server_name,
 | 
						|
                    type(e).__name__,
 | 
						|
                    str(e),
 | 
						|
                )
 | 
						|
 | 
						|
            return {}
 | 
						|
 | 
						|
        results = await make_deferred_yieldable(
 | 
						|
            defer.gatherResults(
 | 
						|
                [run_in_background(get_key, server) for server in self.key_servers],
 | 
						|
                consumeErrors=True,
 | 
						|
            ).addErrback(unwrapFirstError)
 | 
						|
        )
 | 
						|
 | 
						|
        union_of_keys: Dict[str, Dict[str, FetchKeyResult]] = {}
 | 
						|
        for result in results:
 | 
						|
            for server_name, keys in result.items():
 | 
						|
                union_of_keys.setdefault(server_name, {}).update(keys)
 | 
						|
 | 
						|
        return union_of_keys
 | 
						|
 | 
						|
    async def get_server_verify_key_v2_indirect(
 | 
						|
        self, keys_to_fetch: List[_FetchKeyRequest], key_server: TrustedKeyServer
 | 
						|
    ) -> Dict[str, Dict[str, FetchKeyResult]]:
 | 
						|
        """
 | 
						|
        Args:
 | 
						|
            keys_to_fetch:
 | 
						|
                the keys to be fetched.
 | 
						|
 | 
						|
            key_server: notary server to query for the keys
 | 
						|
 | 
						|
        Returns:
 | 
						|
            Map from server_name -> key_id -> FetchKeyResult
 | 
						|
 | 
						|
        Raises:
 | 
						|
            KeyLookupError if there was an error processing the entire response from
 | 
						|
                the server
 | 
						|
        """
 | 
						|
        perspective_name = key_server.server_name
 | 
						|
        logger.info(
 | 
						|
            "Requesting keys %s from notary server %s",
 | 
						|
            keys_to_fetch,
 | 
						|
            perspective_name,
 | 
						|
        )
 | 
						|
 | 
						|
        request: JsonDict = {}
 | 
						|
        for queue_value in keys_to_fetch:
 | 
						|
            # there may be multiple requests for each server, so we have to merge
 | 
						|
            # them intelligently.
 | 
						|
            request_for_server = {
 | 
						|
                key_id: {
 | 
						|
                    "minimum_valid_until_ts": queue_value.minimum_valid_until_ts,
 | 
						|
                }
 | 
						|
                for key_id in queue_value.key_ids
 | 
						|
            }
 | 
						|
            request.setdefault(queue_value.server_name, {}).update(request_for_server)
 | 
						|
 | 
						|
        logger.debug("Request to notary server %s: %s", perspective_name, request)
 | 
						|
 | 
						|
        try:
 | 
						|
            query_response = await self.client.post_json(
 | 
						|
                destination=perspective_name,
 | 
						|
                path="/_matrix/key/v2/query",
 | 
						|
                data={"server_keys": request},
 | 
						|
            )
 | 
						|
        except (NotRetryingDestination, RequestSendFailed) as e:
 | 
						|
            # these both have str() representations which we can't really improve upon
 | 
						|
            raise KeyLookupError(str(e))
 | 
						|
        except HttpResponseException as e:
 | 
						|
            raise KeyLookupError("Remote server returned an error: %s" % (e,))
 | 
						|
 | 
						|
        logger.debug(
 | 
						|
            "Response from notary server %s: %s", perspective_name, query_response
 | 
						|
        )
 | 
						|
 | 
						|
        keys: Dict[str, Dict[str, FetchKeyResult]] = {}
 | 
						|
        added_keys: Dict[Tuple[str, str], FetchKeyResult] = {}
 | 
						|
 | 
						|
        time_now_ms = self.clock.time_msec()
 | 
						|
 | 
						|
        assert isinstance(query_response, dict)
 | 
						|
        for response in query_response["server_keys"]:
 | 
						|
            # do this first, so that we can give useful errors thereafter
 | 
						|
            server_name = response.get("server_name")
 | 
						|
            if not isinstance(server_name, str):
 | 
						|
                raise KeyLookupError(
 | 
						|
                    "Malformed response from key notary server %s: invalid server_name"
 | 
						|
                    % (perspective_name,)
 | 
						|
                )
 | 
						|
 | 
						|
            try:
 | 
						|
                self._validate_perspectives_response(key_server, response)
 | 
						|
 | 
						|
                processed_response = await self.process_v2_response(
 | 
						|
                    perspective_name, response, time_added_ms=time_now_ms
 | 
						|
                )
 | 
						|
            except KeyLookupError as e:
 | 
						|
                logger.warning(
 | 
						|
                    "Error processing response from key notary server %s for origin "
 | 
						|
                    "server %s: %s",
 | 
						|
                    perspective_name,
 | 
						|
                    server_name,
 | 
						|
                    e,
 | 
						|
                )
 | 
						|
                # we continue to process the rest of the response
 | 
						|
                continue
 | 
						|
 | 
						|
            for key_id, key in processed_response.items():
 | 
						|
                dict_key = (server_name, key_id)
 | 
						|
                if dict_key in added_keys:
 | 
						|
                    already_present_key = added_keys[dict_key]
 | 
						|
                    logger.warning(
 | 
						|
                        "Duplicate server keys for %s (%s) from perspective %s (%r, %r)",
 | 
						|
                        server_name,
 | 
						|
                        key_id,
 | 
						|
                        perspective_name,
 | 
						|
                        already_present_key,
 | 
						|
                        key,
 | 
						|
                    )
 | 
						|
 | 
						|
                    if already_present_key.valid_until_ts > key.valid_until_ts:
 | 
						|
                        # Favour the entry with the largest valid_until_ts,
 | 
						|
                        # as `old_verify_keys` are also collected from this
 | 
						|
                        # response.
 | 
						|
                        continue
 | 
						|
 | 
						|
                added_keys[dict_key] = key
 | 
						|
 | 
						|
            keys.setdefault(server_name, {}).update(processed_response)
 | 
						|
 | 
						|
        return keys
 | 
						|
 | 
						|
    def _validate_perspectives_response(
 | 
						|
        self, key_server: TrustedKeyServer, response: JsonDict
 | 
						|
    ) -> None:
 | 
						|
        """Optionally check the signature on the result of a /key/query request
 | 
						|
 | 
						|
        Args:
 | 
						|
            key_server: the notary server that produced this result
 | 
						|
 | 
						|
            response: the json-decoded Server Keys response object
 | 
						|
        """
 | 
						|
        perspective_name = key_server.server_name
 | 
						|
        perspective_keys = key_server.verify_keys
 | 
						|
 | 
						|
        if perspective_keys is None:
 | 
						|
            # signature checking is disabled on this server
 | 
						|
            return
 | 
						|
 | 
						|
        if (
 | 
						|
            "signatures" not in response
 | 
						|
            or perspective_name not in response["signatures"]
 | 
						|
        ):
 | 
						|
            raise KeyLookupError("Response not signed by the notary server")
 | 
						|
 | 
						|
        verified = False
 | 
						|
        for key_id in response["signatures"][perspective_name]:
 | 
						|
            if key_id in perspective_keys:
 | 
						|
                verify_signed_json(response, perspective_name, perspective_keys[key_id])
 | 
						|
                verified = True
 | 
						|
 | 
						|
        if not verified:
 | 
						|
            raise KeyLookupError(
 | 
						|
                "Response not signed with a known key: signed with: %r, known keys: %r"
 | 
						|
                % (
 | 
						|
                    list(response["signatures"][perspective_name].keys()),
 | 
						|
                    list(perspective_keys.keys()),
 | 
						|
                )
 | 
						|
            )
 | 
						|
 | 
						|
 | 
						|
class ServerKeyFetcher(BaseV2KeyFetcher):
 | 
						|
    """KeyFetcher impl which fetches keys from the origin servers"""
 | 
						|
 | 
						|
    def __init__(self, hs: "HomeServer"):
 | 
						|
        super().__init__(hs)
 | 
						|
        self.clock = hs.get_clock()
 | 
						|
        self.client = hs.get_federation_http_client()
 | 
						|
 | 
						|
    async def get_keys(
 | 
						|
        self, server_name: str, key_ids: List[str], minimum_valid_until_ts: int
 | 
						|
    ) -> Dict[str, FetchKeyResult]:
 | 
						|
        results = await self._queue.add_to_queue(
 | 
						|
            _FetchKeyRequest(
 | 
						|
                server_name=server_name,
 | 
						|
                key_ids=key_ids,
 | 
						|
                minimum_valid_until_ts=minimum_valid_until_ts,
 | 
						|
            ),
 | 
						|
            key=server_name,
 | 
						|
        )
 | 
						|
        return results.get(server_name, {})
 | 
						|
 | 
						|
    async def _fetch_keys(
 | 
						|
        self, keys_to_fetch: List[_FetchKeyRequest]
 | 
						|
    ) -> Dict[str, Dict[str, FetchKeyResult]]:
 | 
						|
        """
 | 
						|
        Args:
 | 
						|
            keys_to_fetch:
 | 
						|
                the keys to be fetched. server_name -> key_ids
 | 
						|
 | 
						|
        Returns:
 | 
						|
            Map from server_name -> key_id -> FetchKeyResult
 | 
						|
        """
 | 
						|
 | 
						|
        results = {}
 | 
						|
 | 
						|
        async def get_keys(key_to_fetch_item: _FetchKeyRequest) -> None:
 | 
						|
            server_name = key_to_fetch_item.server_name
 | 
						|
 | 
						|
            try:
 | 
						|
                keys = await self.get_server_verify_keys_v2_direct(server_name)
 | 
						|
                results[server_name] = keys
 | 
						|
            except KeyLookupError as e:
 | 
						|
                logger.warning("Error looking up keys from %s: %s", server_name, e)
 | 
						|
            except Exception:
 | 
						|
                logger.exception("Error getting keys from %s", server_name)
 | 
						|
 | 
						|
        await yieldable_gather_results(get_keys, keys_to_fetch)
 | 
						|
        return results
 | 
						|
 | 
						|
    async def get_server_verify_keys_v2_direct(
 | 
						|
        self, server_name: str
 | 
						|
    ) -> Dict[str, FetchKeyResult]:
 | 
						|
        """
 | 
						|
 | 
						|
        Args:
 | 
						|
            server_name: Server to request keys from
 | 
						|
 | 
						|
        Returns:
 | 
						|
            Map from key ID to lookup result
 | 
						|
 | 
						|
        Raises:
 | 
						|
            KeyLookupError if there was a problem making the lookup
 | 
						|
        """
 | 
						|
        time_now_ms = self.clock.time_msec()
 | 
						|
        try:
 | 
						|
            response = await self.client.get_json(
 | 
						|
                destination=server_name,
 | 
						|
                path="/_matrix/key/v2/server",
 | 
						|
                ignore_backoff=True,
 | 
						|
                # we only give the remote server 10s to respond. It should be an
 | 
						|
                # easy request to handle, so if it doesn't reply within 10s, it's
 | 
						|
                # probably not going to.
 | 
						|
                #
 | 
						|
                # Furthermore, when we are acting as a notary server, we cannot
 | 
						|
                # wait all day for all of the origin servers, as the requesting
 | 
						|
                # server will otherwise time out before we can respond.
 | 
						|
                #
 | 
						|
                # (Note that get_json may make 4 attempts, so this can still take
 | 
						|
                # almost 45 seconds to fetch the headers, plus up to another 60s to
 | 
						|
                # read the response).
 | 
						|
                timeout=10000,
 | 
						|
            )
 | 
						|
        except (NotRetryingDestination, RequestSendFailed) as e:
 | 
						|
            # these both have str() representations which we can't really improve
 | 
						|
            # upon
 | 
						|
            raise KeyLookupError(str(e))
 | 
						|
        except HttpResponseException as e:
 | 
						|
            raise KeyLookupError("Remote server returned an error: %s" % (e,))
 | 
						|
 | 
						|
        assert isinstance(response, dict)
 | 
						|
        if response["server_name"] != server_name:
 | 
						|
            raise KeyLookupError(
 | 
						|
                "Expected a response for server %r not %r"
 | 
						|
                % (server_name, response["server_name"])
 | 
						|
            )
 | 
						|
 | 
						|
        return await self.process_v2_response(
 | 
						|
            from_server=server_name,
 | 
						|
            response_json=response,
 | 
						|
            time_added_ms=time_now_ms,
 | 
						|
        )
 |