Refactor storing of server keys (#16261)

pull/16309/head
Erik Johnston 2023-09-12 11:08:04 +01:00 committed by GitHub
parent 9400dc0535
commit 2b35626b6b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 111 additions and 370 deletions

1
changelog.d/16261.misc Normal file
View File

@ -0,0 +1 @@
Simplify server key storage.

View File

@ -23,12 +23,7 @@ from signedjson.key import (
get_verify_key,
is_signing_algorithm_supported,
)
from signedjson.sign import (
SignatureVerifyException,
encode_canonical_json,
signature_ids,
verify_signed_json,
)
from signedjson.sign import SignatureVerifyException, signature_ids, verify_signed_json
from signedjson.types import VerifyKey
from unpaddedbase64 import decode_base64
@ -596,24 +591,12 @@ class BaseV2KeyFetcher(KeyFetcher):
verify_key=verify_key, valid_until_ts=key_data["expired_ts"]
)
key_json_bytes = encode_canonical_json(response_json)
await make_deferred_yieldable(
defer.gatherResults(
[
run_in_background(
self.store.store_server_keys_json,
server_name=server_name,
key_id=key_id,
from_server=from_server,
ts_now_ms=time_added_ms,
ts_expires_ms=ts_valid_until_ms,
key_json_bytes=key_json_bytes,
)
for key_id in verify_keys
],
consumeErrors=True,
).addErrback(unwrapFirstError)
await self.store.store_server_keys_response(
server_name=server_name,
from_server=from_server,
ts_added_ms=time_added_ms,
verify_keys=verify_keys,
response_json=response_json,
)
return verify_keys
@ -775,10 +758,6 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
keys.setdefault(server_name, {}).update(processed_response)
await self.store.store_server_signature_keys(
perspective_name, time_now_ms, added_keys
)
return keys
def _validate_perspectives_response(

View File

@ -16,14 +16,17 @@
import itertools
import json
import logging
from typing import Dict, Iterable, Mapping, Optional, Tuple
from typing import Dict, Iterable, Optional, Tuple
from canonicaljson import encode_canonical_json
from signedjson.key import decode_verify_key_bytes
from unpaddedbase64 import decode_base64
from synapse.storage.database import LoggingTransaction
from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
from synapse.storage.keys import FetchKeyResult, FetchKeyResultForRemote
from synapse.storage.types import Cursor
from synapse.types import JsonDict
from synapse.util.caches.descriptors import cached, cachedList
from synapse.util.iterutils import batch_iter
@ -36,162 +39,84 @@ db_binary_type = memoryview
class KeyStore(CacheInvalidationWorkerStore):
"""Persistence for signature verification keys"""
@cached()
def _get_server_signature_key(
self, server_name_and_key_id: Tuple[str, str]
) -> FetchKeyResult:
raise NotImplementedError()
@cachedList(
cached_method_name="_get_server_signature_key",
list_name="server_name_and_key_ids",
)
async def get_server_signature_keys(
self, server_name_and_key_ids: Iterable[Tuple[str, str]]
) -> Dict[Tuple[str, str], FetchKeyResult]:
"""
Args:
server_name_and_key_ids:
iterable of (server_name, key-id) tuples to fetch keys for
Returns:
A map from (server_name, key_id) -> FetchKeyResult, or None if the
key is unknown
"""
keys = {}
def _get_keys(txn: Cursor, batch: Tuple[Tuple[str, str], ...]) -> None:
"""Processes a batch of keys to fetch, and adds the result to `keys`."""
# batch_iter always returns tuples so it's safe to do len(batch)
sql = """
SELECT server_name, key_id, verify_key, ts_valid_until_ms
FROM server_signature_keys WHERE 1=0
""" + " OR (server_name=? AND key_id=?)" * len(
batch
)
txn.execute(sql, tuple(itertools.chain.from_iterable(batch)))
for row in txn:
server_name, key_id, key_bytes, ts_valid_until_ms = row
if ts_valid_until_ms is None:
# Old keys may be stored with a ts_valid_until_ms of null,
# in which case we treat this as if it was set to `0`, i.e.
# it won't match key requests that define a minimum
# `ts_valid_until_ms`.
ts_valid_until_ms = 0
keys[(server_name, key_id)] = FetchKeyResult(
verify_key=decode_verify_key_bytes(key_id, bytes(key_bytes)),
valid_until_ts=ts_valid_until_ms,
)
def _txn(txn: Cursor) -> Dict[Tuple[str, str], FetchKeyResult]:
for batch in batch_iter(server_name_and_key_ids, 50):
_get_keys(txn, batch)
return keys
return await self.db_pool.runInteraction("get_server_signature_keys", _txn)
async def store_server_signature_keys(
self,
from_server: str,
ts_added_ms: int,
verify_keys: Mapping[Tuple[str, str], FetchKeyResult],
) -> None:
"""Stores NACL verification keys for remote servers.
Args:
from_server: Where the verification keys were looked up
ts_added_ms: The time to record that the key was added
verify_keys:
keys to be stored. Each entry is a triplet of
(server_name, key_id, key).
"""
key_values = []
value_values = []
invalidations = []
for (server_name, key_id), fetch_result in verify_keys.items():
key_values.append((server_name, key_id))
value_values.append(
(
from_server,
ts_added_ms,
fetch_result.valid_until_ts,
db_binary_type(fetch_result.verify_key.encode()),
)
)
# invalidate takes a tuple corresponding to the params of
# _get_server_signature_key. _get_server_signature_key only takes one
# param, which is itself the 2-tuple (server_name, key_id).
invalidations.append((server_name, key_id))
await self.db_pool.simple_upsert_many(
table="server_signature_keys",
key_names=("server_name", "key_id"),
key_values=key_values,
value_names=(
"from_server",
"ts_added_ms",
"ts_valid_until_ms",
"verify_key",
),
value_values=value_values,
desc="store_server_signature_keys",
)
invalidate = self._get_server_signature_key.invalidate
for i in invalidations:
invalidate((i,))
async def store_server_keys_json(
async def store_server_keys_response(
self,
server_name: str,
key_id: str,
from_server: str,
ts_now_ms: int,
ts_expires_ms: int,
key_json_bytes: bytes,
ts_added_ms: int,
verify_keys: Dict[str, FetchKeyResult],
response_json: JsonDict,
) -> None:
"""Stores the JSON bytes for a set of keys from a server
The JSON should be signed by the originating server, the intermediate
server, and by this server. Updates the value for the
(server_name, key_id, from_server) triplet if one already existed.
Args:
server_name: The name of the server.
key_id: The identifier of the key this JSON is for.
from_server: The server this JSON was fetched from.
ts_now_ms: The time now in milliseconds.
ts_valid_until_ms: The time when this json stops being valid.
key_json_bytes: The encoded JSON.
"""
await self.db_pool.simple_upsert(
table="server_keys_json",
keyvalues={
"server_name": server_name,
"key_id": key_id,
"from_server": from_server,
},
values={
"server_name": server_name,
"key_id": key_id,
"from_server": from_server,
"ts_added_ms": ts_now_ms,
"ts_valid_until_ms": ts_expires_ms,
"key_json": db_binary_type(key_json_bytes),
},
desc="store_server_keys_json",
)
"""Stores the keys for the given server that we got from `from_server`.
# invalidate takes a tuple corresponding to the params of
# _get_server_keys_json. _get_server_keys_json only takes one
# param, which is itself the 2-tuple (server_name, key_id).
await self.invalidate_cache_and_stream(
"_get_server_keys_json", ((server_name, key_id),)
)
await self.invalidate_cache_and_stream(
"get_server_key_json_for_remote", (server_name, key_id)
Args:
server_name: The owner of the keys
from_server: Which server we got the keys from
ts_added_ms: When we're adding the keys
verify_keys: The decoded keys
response_json: The full *signed* response JSON that contains the keys.
"""
key_json_bytes = encode_canonical_json(response_json)
def store_server_keys_response_txn(txn: LoggingTransaction) -> None:
self.db_pool.simple_upsert_many_txn(
txn,
table="server_signature_keys",
key_names=("server_name", "key_id"),
key_values=[(server_name, key_id) for key_id in verify_keys],
value_names=(
"from_server",
"ts_added_ms",
"ts_valid_until_ms",
"verify_key",
),
value_values=[
(
from_server,
ts_added_ms,
fetch_result.valid_until_ts,
db_binary_type(fetch_result.verify_key.encode()),
)
for fetch_result in verify_keys.values()
],
)
self.db_pool.simple_upsert_many_txn(
txn,
table="server_keys_json",
key_names=("server_name", "key_id", "from_server"),
key_values=[
(server_name, key_id, from_server) for key_id in verify_keys
],
value_names=(
"ts_added_ms",
"ts_valid_until_ms",
"key_json",
),
value_values=[
(
ts_added_ms,
fetch_result.valid_until_ts,
db_binary_type(key_json_bytes),
)
for fetch_result in verify_keys.values()
],
)
# invalidate takes a tuple corresponding to the params of
# _get_server_keys_json. _get_server_keys_json only takes one
# param, which is itself the 2-tuple (server_name, key_id).
for key_id in verify_keys:
self._invalidate_cache_and_stream(
txn, self._get_server_keys_json, ((server_name, key_id),)
)
self._invalidate_cache_and_stream(
txn, self.get_server_key_json_for_remote, (server_name, key_id)
)
await self.db_pool.runInteraction(
"store_server_keys_response", store_server_keys_response_txn
)
@cached()

View File

@ -13,7 +13,7 @@
# limitations under the License.
import time
from typing import Any, Dict, List, Optional, cast
from unittest.mock import AsyncMock, Mock
from unittest.mock import Mock
import attr
import canonicaljson
@ -189,23 +189,24 @@ class KeyringTestCase(unittest.HomeserverTestCase):
kr = keyring.Keyring(self.hs)
key1 = signedjson.key.generate_signing_key("1")
r = self.hs.get_datastores().main.store_server_keys_json(
r = self.hs.get_datastores().main.store_server_keys_response(
"server9",
get_key_id(key1),
from_server="test",
ts_now_ms=int(time.time() * 1000),
ts_expires_ms=1000,
ts_added_ms=int(time.time() * 1000),
verify_keys={
get_key_id(key1): FetchKeyResult(
verify_key=get_verify_key(key1), valid_until_ts=1000
)
},
# The entire response gets signed & stored, just include the bits we
# care about.
key_json_bytes=canonicaljson.encode_canonical_json(
{
"verify_keys": {
get_key_id(key1): {
"key": encode_verify_key_base64(get_verify_key(key1))
}
response_json={
"verify_keys": {
get_key_id(key1): {
"key": encode_verify_key_base64(get_verify_key(key1))
}
}
),
},
)
self.get_success(r)
@ -285,34 +286,6 @@ class KeyringTestCase(unittest.HomeserverTestCase):
d = kr.verify_json_for_server(self.hs.hostname, json1, 0)
self.get_success(d)
def test_verify_json_for_server_with_null_valid_until_ms(self) -> None:
"""Tests that we correctly handle key requests for keys we've stored
with a null `ts_valid_until_ms`
"""
mock_fetcher = Mock()
mock_fetcher.get_keys = AsyncMock(return_value={})
key1 = signedjson.key.generate_signing_key("1")
r = self.hs.get_datastores().main.store_server_signature_keys(
"server9",
int(time.time() * 1000),
# None is not a valid value in FetchKeyResult, but we're abusing this
# API to insert null values into the database. The nulls get converted
# to 0 when fetched in KeyStore.get_server_signature_keys.
{("server9", get_key_id(key1)): FetchKeyResult(get_verify_key(key1), None)}, # type: ignore[arg-type]
)
self.get_success(r)
json1: JsonDict = {}
signedjson.sign.sign_json(json1, "server9", key1)
# should succeed on a signed object with a 0 minimum_valid_until_ms
d = self.hs.get_datastores().main.get_server_signature_keys(
[("server9", get_key_id(key1))]
)
result = self.get_success(d)
self.assertEqual(result[("server9", get_key_id(key1))].valid_until_ts, 0)
def test_verify_json_dedupes_key_requests(self) -> None:
"""Two requests for the same key should be deduped."""
key1 = signedjson.key.generate_signing_key("1")

View File

@ -1,137 +0,0 @@
# Copyright 2017 Vector Creations Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import signedjson.key
import signedjson.types
import unpaddedbase64
from synapse.storage.keys import FetchKeyResult
import tests.unittest
def decode_verify_key_base64(
key_id: str, key_base64: str
) -> signedjson.types.VerifyKey:
key_bytes = unpaddedbase64.decode_base64(key_base64)
return signedjson.key.decode_verify_key_bytes(key_id, key_bytes)
KEY_1 = decode_verify_key_base64(
"ed25519:key1", "fP5l4JzpZPq/zdbBg5xx6lQGAAOM9/3w94cqiJ5jPrw"
)
KEY_2 = decode_verify_key_base64(
"ed25519:key2", "Noi6WqcDj0QmPxCNQqgezwTlBKrfqehY1u2FyWP9uYw"
)
class KeyStoreTestCase(tests.unittest.HomeserverTestCase):
def test_get_server_signature_keys(self) -> None:
store = self.hs.get_datastores().main
key_id_1 = "ed25519:key1"
key_id_2 = "ed25519:KEY_ID_2"
self.get_success(
store.store_server_signature_keys(
"from_server",
10,
{
("server1", key_id_1): FetchKeyResult(KEY_1, 100),
("server1", key_id_2): FetchKeyResult(KEY_2, 200),
},
)
)
res = self.get_success(
store.get_server_signature_keys(
[
("server1", key_id_1),
("server1", key_id_2),
("server1", "ed25519:key3"),
]
)
)
self.assertEqual(len(res.keys()), 3)
res1 = res[("server1", key_id_1)]
self.assertEqual(res1.verify_key, KEY_1)
self.assertEqual(res1.verify_key.version, "key1")
self.assertEqual(res1.valid_until_ts, 100)
res2 = res[("server1", key_id_2)]
self.assertEqual(res2.verify_key, KEY_2)
# version comes from the ID it was stored with
self.assertEqual(res2.verify_key.version, "KEY_ID_2")
self.assertEqual(res2.valid_until_ts, 200)
# non-existent result gives None
self.assertIsNone(res[("server1", "ed25519:key3")])
def test_cache(self) -> None:
"""Check that updates correctly invalidate the cache."""
store = self.hs.get_datastores().main
key_id_1 = "ed25519:key1"
key_id_2 = "ed25519:key2"
self.get_success(
store.store_server_signature_keys(
"from_server",
0,
{
("srv1", key_id_1): FetchKeyResult(KEY_1, 100),
("srv1", key_id_2): FetchKeyResult(KEY_2, 200),
},
)
)
res = self.get_success(
store.get_server_signature_keys([("srv1", key_id_1), ("srv1", key_id_2)])
)
self.assertEqual(len(res.keys()), 2)
res1 = res[("srv1", key_id_1)]
self.assertEqual(res1.verify_key, KEY_1)
self.assertEqual(res1.valid_until_ts, 100)
res2 = res[("srv1", key_id_2)]
self.assertEqual(res2.verify_key, KEY_2)
self.assertEqual(res2.valid_until_ts, 200)
# we should be able to look up the same thing again without a db hit
res = self.get_success(store.get_server_signature_keys([("srv1", key_id_1)]))
self.assertEqual(len(res.keys()), 1)
self.assertEqual(res[("srv1", key_id_1)].verify_key, KEY_1)
new_key_2 = signedjson.key.get_verify_key(
signedjson.key.generate_signing_key("key2")
)
d = store.store_server_signature_keys(
"from_server", 10, {("srv1", key_id_2): FetchKeyResult(new_key_2, 300)}
)
self.get_success(d)
res = self.get_success(
store.get_server_signature_keys([("srv1", key_id_1), ("srv1", key_id_2)])
)
self.assertEqual(len(res.keys()), 2)
res1 = res[("srv1", key_id_1)]
self.assertEqual(res1.verify_key, KEY_1)
self.assertEqual(res1.valid_until_ts, 100)
res2 = res[("srv1", key_id_2)]
self.assertEqual(res2.verify_key, new_key_2)
self.assertEqual(res2.valid_until_ts, 300)

View File

@ -70,6 +70,7 @@ from synapse.logging.context import (
)
from synapse.rest import RegisterServletsFunc
from synapse.server import HomeServer
from synapse.storage.keys import FetchKeyResult
from synapse.types import JsonDict, Requester, UserID, create_requester
from synapse.util import Clock
from synapse.util.httpresourcetree import create_resource_tree
@ -858,23 +859,22 @@ class FederatingHomeserverTestCase(HomeserverTestCase):
verify_key_id = "%s:%s" % (verify_key.alg, verify_key.version)
self.get_success(
hs.get_datastores().main.store_server_keys_json(
hs.get_datastores().main.store_server_keys_response(
self.OTHER_SERVER_NAME,
verify_key_id,
from_server=self.OTHER_SERVER_NAME,
ts_now_ms=clock.time_msec(),
ts_expires_ms=clock.time_msec() + 10000,
key_json_bytes=canonicaljson.encode_canonical_json(
{
"verify_keys": {
verify_key_id: {
"key": signedjson.key.encode_verify_key_base64(
verify_key
)
}
ts_added_ms=clock.time_msec(),
verify_keys={
verify_key_id: FetchKeyResult(
verify_key=verify_key, valid_until_ts=clock.time_msec() + 10000
),
},
response_json={
"verify_keys": {
verify_key_id: {
"key": signedjson.key.encode_verify_key_base64(verify_key)
}
}
),
},
)
)