303 lines
		
	
	
		
			11 KiB
		
	
	
	
		
			Python
		
	
	
			
		
		
	
	
			303 lines
		
	
	
		
			11 KiB
		
	
	
	
		
			Python
		
	
	
# Copyright 2015, 2016 OpenMarket Ltd
 | 
						|
#
 | 
						|
# Licensed under the Apache License, Version 2.0 (the "License");
 | 
						|
# you may not use this file except in compliance with the License.
 | 
						|
# You may obtain a copy of the License at
 | 
						|
#
 | 
						|
#     http://www.apache.org/licenses/LICENSE-2.0
 | 
						|
#
 | 
						|
# Unless required by applicable law or agreed to in writing, software
 | 
						|
# distributed under the License is distributed on an "AS IS" BASIS,
 | 
						|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
						|
# See the License for the specific language governing permissions and
 | 
						|
# limitations under the License.
 | 
						|
 | 
						|
import logging
 | 
						|
import re
 | 
						|
from typing import TYPE_CHECKING, Dict, Mapping, Optional, Set, Tuple
 | 
						|
 | 
						|
from synapse._pydantic_compat import HAS_PYDANTIC_V2
 | 
						|
 | 
						|
if TYPE_CHECKING or HAS_PYDANTIC_V2:
 | 
						|
    from pydantic.v1 import Extra, StrictInt, StrictStr
 | 
						|
else:
 | 
						|
    from pydantic import StrictInt, StrictStr, Extra
 | 
						|
 | 
						|
from signedjson.sign import sign_json
 | 
						|
 | 
						|
from twisted.web.server import Request
 | 
						|
 | 
						|
from synapse.crypto.keyring import ServerKeyFetcher
 | 
						|
from synapse.http.server import HttpServer
 | 
						|
from synapse.http.servlet import (
 | 
						|
    RestServlet,
 | 
						|
    parse_and_validate_json_object_from_request,
 | 
						|
    parse_integer,
 | 
						|
)
 | 
						|
from synapse.rest.models import RequestBodyModel
 | 
						|
from synapse.storage.keys import FetchKeyResultForRemote
 | 
						|
from synapse.types import JsonDict
 | 
						|
from synapse.util import json_decoder
 | 
						|
from synapse.util.async_helpers import yieldable_gather_results
 | 
						|
 | 
						|
if TYPE_CHECKING:
 | 
						|
    from synapse.server import HomeServer
 | 
						|
 | 
						|
logger = logging.getLogger(__name__)
 | 
						|
 | 
						|
 | 
						|
class _KeyQueryCriteriaDataModel(RequestBodyModel):
    """Criteria a requester may attach to a single key ID in a key query."""

    # When set, a cached key whose `valid_until_ts` is below this value is
    # treated as a cache miss by `RemoteKey.query_keys`.
    minimum_valid_until_ts: Optional[StrictInt]

    class Config:
        # Accept (rather than reject) fields this model does not declare.
        extra = Extra.allow
 | 
						|
 | 
						|
 | 
						|
class RemoteKey(RestServlet):
 | 
						|
    """HTTP resource for retrieving the TLS certificate and NACL signature
 | 
						|
    verification keys for a collection of servers. Checks that the reported
 | 
						|
    X.509 TLS certificate matches the one used in the HTTPS connection. Checks
 | 
						|
    that the NACL signature for the remote server is valid. Returns a dict of
 | 
						|
    JSON signed by both the remote server and by this server.
 | 
						|
 | 
						|
    Supports individual GET APIs and a bulk query POST API.
 | 
						|
 | 
						|
    Requests:
 | 
						|
 | 
						|
    GET /_matrix/key/v2/query/remote.server.example.com HTTP/1.1
 | 
						|
 | 
						|
    GET /_matrix/key/v2/query/remote.server.example.com/a.key.id HTTP/1.1
 | 
						|
 | 
						|
    POST /_matrix/key/v2/query HTTP/1.1
 | 
						|
    Content-Type: application/json
 | 
						|
    {
 | 
						|
        "server_keys": {
 | 
						|
            "remote.server.example.com": {
 | 
						|
                "a.key.id": {
 | 
						|
                    "minimum_valid_until_ts": 1234567890123
 | 
						|
                }
 | 
						|
            }
 | 
						|
        }
 | 
						|
    }
 | 
						|
 | 
						|
    Response:
 | 
						|
 | 
						|
    HTTP/1.1 200 OK
 | 
						|
    Content-Type: application/json
 | 
						|
    {
 | 
						|
        "server_keys": [
 | 
						|
            {
 | 
						|
                "server_name": "remote.server.example.com"
 | 
						|
                "valid_until_ts": # posix timestamp
 | 
						|
                "verify_keys": {
 | 
						|
                    "a.key.id": { # The identifier for a key.
 | 
						|
                        key: "" # base64 encoded verification key.
 | 
						|
                    }
 | 
						|
                }
 | 
						|
                "old_verify_keys": {
 | 
						|
                    "an.old.key.id": { # The identifier for an old key.
 | 
						|
                        key: "", # base64 encoded key
 | 
						|
                        "expired_ts": 0, # when the key stopped being used.
 | 
						|
                    }
 | 
						|
                }
 | 
						|
                "signatures": {
 | 
						|
                    "remote.server.example.com": {...}
 | 
						|
                    "this.server.example.com": {...}
 | 
						|
                }
 | 
						|
            }
 | 
						|
        ]
 | 
						|
    }
 | 
						|
    """
 | 
						|
 | 
						|
    CATEGORY = "Federation requests"
 | 
						|
 | 
						|
    class PostBody(RequestBodyModel):
        """Shape of the JSON body accepted by the bulk-query POST handler."""

        # server name -> key ID -> criteria the cached key must satisfy.
        server_keys: Dict[StrictStr, Dict[StrictStr, _KeyQueryCriteriaDataModel]]
 | 
						|
 | 
						|
    def __init__(self, hs: "HomeServer"):
        """Gather the homeserver components this servlet relies on.

        Args:
            hs: The homeserver instance; supplies the key fetcher, the main
                datastore, the clock and the configuration.
        """
        homeserver_config = hs.config

        self.fetcher = ServerKeyFetcher(hs)
        self.store = hs.get_datastores().main
        self.clock = hs.get_clock()
        # Consulted in `query_keys` before fetching keys from a remote server;
        # `None` is treated there as "no restriction".
        self.federation_domain_whitelist = (
            homeserver_config.federation.federation_domain_whitelist
        )
        self.config = homeserver_config
 | 
						|
 | 
						|
    def register(self, http_server: HttpServer) -> None:
        """Attach the GET and POST key-query handlers to `http_server`.

        Args:
            http_server: The server on which the paths are registered.
        """
        servlet_name = self.__class__.__name__

        # GET: a server name followed by an optional key ID path segment.
        per_server_pattern = re.compile(
            "^/_matrix/key/v2/query/(?P<server>[^/]*)(/(?P<key_id>[^/]*))?$"
        )
        # POST: the bulk query endpoint, which takes no path parameters.
        bulk_pattern = re.compile("^/_matrix/key/v2/query$")

        routes = (
            ("GET", per_server_pattern, self.on_GET),
            ("POST", bulk_pattern, self.on_POST),
        )
        for http_method, pattern, handler in routes:
            http_server.register_paths(
                http_method, (pattern,), handler, servlet_name
            )
 | 
						|
 | 
						|
    async def on_GET(
        self, request: Request, server: str, key_id: Optional[str] = None
    ) -> Tuple[int, JsonDict]:
        """Handle a GET query for the keys of a single server.

        Args:
            request: The incoming request; its `minimum_valid_until_ts` query
                parameter is only read when a key ID is given.
            server: The server name captured from the request path.
            key_id: The optional (deprecated) key ID captured from the path.

        Returns:
            A `(200, response_body)` tuple, where the body is the result of
            `query_keys`.
        """
        # An empty criteria mapping means "all keys of this server".
        criteria_by_key_id: Dict[str, _KeyQueryCriteriaDataModel] = {}

        if server and key_id:
            # Passing a key ID was dropped in Matrix 1.6. It is still accepted
            # for compatibility with earlier versions, and a warning is logged
            # so usage can be gauged before support is removed.
            logger.warning(
                "Request for remote server key with deprecated key ID (logging to determine usage level for future removal): %s / %s",
                server,
                key_id,
            )

            criteria_by_key_id[key_id] = _KeyQueryCriteriaDataModel(
                minimum_valid_until_ts=parse_integer(
                    request, "minimum_valid_until_ts"
                )
            )

        response_body = await self.query_keys(
            {server: criteria_by_key_id}, query_remote_on_cache_miss=True
        )
        return 200, response_body
 | 
						|
 | 
						|
    async def on_POST(self, request: Request) -> Tuple[int, JsonDict]:
        """Handle a bulk key query submitted as a JSON body.

        Args:
            request: The incoming request; its body is parsed and validated
                against `PostBody`.

        Returns:
            A `(200, response_body)` tuple, where the body is the result of
            `query_keys`.
        """
        body = parse_and_validate_json_object_from_request(request, self.PostBody)

        response_body = await self.query_keys(
            body.server_keys, query_remote_on_cache_miss=True
        )
        return 200, response_body
 | 
						|
 | 
						|
    async def query_keys(
        self,
        query: Dict[str, Dict[str, _KeyQueryCriteriaDataModel]],
        query_remote_on_cache_miss: bool = False,
    ) -> JsonDict:
        """Look up the requested server keys and return them signed by us.

        Stored key responses are read from the datastore first. A stored
        response counts as a cache miss when it is absent, when its
        `valid_until_ts` is below the requested `minimum_valid_until_ts`, or,
        if no minimum was requested, when more than half of its lifetime has
        elapsed. When misses exist and `query_remote_on_cache_miss` is set, the
        missing keys are fetched through `self.fetcher` and this method is
        re-run once with remote querying disabled.

        Args:
            query: Maps server name -> key ID -> criteria. An empty inner
                mapping requests every stored key for that server, in which
                case no validity checks are applied.
            query_remote_on_cache_miss: Whether cache misses should trigger a
                fetch from the remote server (subject to the federation
                domain whitelist).

        Returns:
            `{"server_keys": [...]}`, each entry being a stored key response
            additionally signed with every configured key-server signing key.
        """
        logger.info("Handling query for keys %r", query)

        server_keys: Dict[Tuple[str, str], Optional[FetchKeyResultForRemote]] = {}
        for server_name, key_ids in query.items():
            if key_ids:
                results: Mapping[
                    str, Optional[FetchKeyResultForRemote]
                ] = await self.store.get_server_keys_json_for_remote(
                    server_name, key_ids
                )
            else:
                results = await self.store.get_all_server_keys_json_for_remote(
                    server_name
                )

            server_keys.update(
                ((server_name, key_id), res) for key_id, res in results.items()
            )

        # A set, so that identical stored responses are only returned once.
        json_results: Set[bytes] = set()

        time_now_ms = self.clock.time_msec()

        # Map server_name->key_id->int. Note that the value of the int is unused.
        # XXX: why don't we just use a set?
        cache_misses: Dict[str, Dict[str, int]] = {}
        for (server_name, key_id), key_result in server_keys.items():
            if not query[server_name]:
                # All keys were requested. Just return what we have without
                # worrying about validity.
                if key_result:
                    json_results.add(key_result.key_json)
                continue

            miss = False
            if key_result is None:
                miss = True
            else:
                ts_added_ms = key_result.added_ts
                ts_valid_until_ms = key_result.valid_until_ts
                req_key = query.get(server_name, {}).get(
                    key_id, _KeyQueryCriteriaDataModel(minimum_valid_until_ts=None)
                )
                req_valid_until = req_key.minimum_valid_until_ts
                if req_valid_until is not None:
                    if ts_valid_until_ms < req_valid_until:
                        logger.debug(
                            "Cached response for %r/%r is older than requested"
                            ": valid_until (%r) < minimum_valid_until (%r)",
                            server_name,
                            key_id,
                            ts_valid_until_ms,
                            req_valid_until,
                        )
                        miss = True
                    else:
                        logger.debug(
                            "Cached response for %r/%r is newer than requested"
                            ": valid_until (%r) >= minimum_valid_until (%r)",
                            server_name,
                            key_id,
                            ts_valid_until_ms,
                            req_valid_until,
                        )
                elif (ts_added_ms + ts_valid_until_ms) / 2 < time_now_ms:
                    logger.debug(
                        "Cached response for %r/%r is too old"
                        ": (added (%r) + valid_until (%r)) / 2 < now (%r)",
                        server_name,
                        key_id,
                        ts_added_ms,
                        ts_valid_until_ms,
                        time_now_ms,
                    )
                    # We are more than half way through the lifetime of the
                    # response. We should fetch a fresh copy.
                    miss = True
                else:
                    # This branch is reached when the midpoint is *not* in the
                    # past, so the message reports `>=` to match the condition.
                    logger.debug(
                        "Cached response for %r/%r is still valid"
                        ": (added (%r) + valid_until (%r)) / 2 >= now (%r)",
                        server_name,
                        key_id,
                        ts_added_ms,
                        ts_valid_until_ms,
                        time_now_ms,
                    )

                # Even a stale cached response is kept as a candidate result.
                json_results.add(key_result.key_json)

            if miss and query_remote_on_cache_miss:
                # Only bother attempting to fetch keys from servers on our
                # whitelist.
                if (
                    self.federation_domain_whitelist is None
                    or server_name in self.federation_domain_whitelist
                ):
                    cache_misses.setdefault(server_name, {})[key_id] = 0

        # If there is a cache miss, request the missing keys, then recurse (and
        # ensure the result is sent).
        if cache_misses:
            await yieldable_gather_results(
                lambda t: self.fetcher.get_keys(*t),
                (
                    (server_name, list(keys), 0)
                    for server_name, keys in cache_misses.items()
                ),
            )
            # Remote querying is disabled on the second pass, so the recursion
            # goes at most one level deep.
            return await self.query_keys(query, query_remote_on_cache_miss=False)
        else:
            signed_keys = []
            for key_json_raw in json_results:
                key_json = json_decoder.decode(key_json_raw.decode("utf-8"))
                # Add one signature of ours per configured signing key.
                for signing_key in self.config.key.key_server_signing_keys:
                    key_json = sign_json(
                        key_json, self.config.server.server_name, signing_key
                    )

                signed_keys.append(key_json)

            return {"server_keys": signed_keys}
 |