Add cache to `get_server_keys_json_for_remote` (#16123)
parent
54a51ff6c1
commit
0aba4a4eaa
|
@ -0,0 +1 @@
|
|||
Add cache to `get_server_keys_json_for_remote`.
|
|
@ -14,7 +14,7 @@
|
|||
|
||||
import logging
|
||||
import re
|
||||
from typing import TYPE_CHECKING, Dict, Optional, Set, Tuple
|
||||
from typing import TYPE_CHECKING, Dict, Mapping, Optional, Set, Tuple
|
||||
|
||||
from signedjson.sign import sign_json
|
||||
|
||||
|
@ -27,6 +27,7 @@ from synapse.http.servlet import (
|
|||
parse_integer,
|
||||
parse_json_object_from_request,
|
||||
)
|
||||
from synapse.storage.keys import FetchKeyResultForRemote
|
||||
from synapse.types import JsonDict
|
||||
from synapse.util import json_decoder
|
||||
from synapse.util.async_helpers import yieldable_gather_results
|
||||
|
@ -157,14 +158,22 @@ class RemoteKey(RestServlet):
|
|||
) -> JsonDict:
|
||||
logger.info("Handling query for keys %r", query)
|
||||
|
||||
store_queries = []
|
||||
server_keys: Dict[Tuple[str, str], Optional[FetchKeyResultForRemote]] = {}
|
||||
for server_name, key_ids in query.items():
|
||||
if not key_ids:
|
||||
key_ids = (None,)
|
||||
for key_id in key_ids:
|
||||
store_queries.append((server_name, key_id, None))
|
||||
if key_ids:
|
||||
results: Mapping[
|
||||
str, Optional[FetchKeyResultForRemote]
|
||||
] = await self.store.get_server_keys_json_for_remote(
|
||||
server_name, key_ids
|
||||
)
|
||||
else:
|
||||
results = await self.store.get_all_server_keys_json_for_remote(
|
||||
server_name
|
||||
)
|
||||
|
||||
cached = await self.store.get_server_keys_json_for_remote(store_queries)
|
||||
server_keys.update(
|
||||
((server_name, key_id), res) for key_id, res in results.items()
|
||||
)
|
||||
|
||||
json_results: Set[bytes] = set()
|
||||
|
||||
|
@ -173,23 +182,20 @@ class RemoteKey(RestServlet):
|
|||
# Map server_name->key_id->int. Note that the value of the int is unused.
|
||||
# XXX: why don't we just use a set?
|
||||
cache_misses: Dict[str, Dict[str, int]] = {}
|
||||
for (server_name, key_id, _), key_results in cached.items():
|
||||
results = [(result["ts_added_ms"], result) for result in key_results]
|
||||
|
||||
if key_id is None:
|
||||
for (server_name, key_id), key_result in server_keys.items():
|
||||
if not query[server_name]:
|
||||
# all keys were requested. Just return what we have without worrying
|
||||
# about validity
|
||||
for _, result in results:
|
||||
# Cast to bytes since postgresql returns a memoryview.
|
||||
json_results.add(bytes(result["key_json"]))
|
||||
if key_result:
|
||||
json_results.add(key_result.key_json)
|
||||
continue
|
||||
|
||||
miss = False
|
||||
if not results:
|
||||
if key_result is None:
|
||||
miss = True
|
||||
else:
|
||||
ts_added_ms, most_recent_result = max(results)
|
||||
ts_valid_until_ms = most_recent_result["ts_valid_until_ms"]
|
||||
ts_added_ms = key_result.added_ts
|
||||
ts_valid_until_ms = key_result.valid_until_ts
|
||||
req_key = query.get(server_name, {}).get(key_id, {})
|
||||
req_valid_until = req_key.get("minimum_valid_until_ts")
|
||||
if req_valid_until is not None:
|
||||
|
@ -235,8 +241,8 @@ class RemoteKey(RestServlet):
|
|||
ts_valid_until_ms,
|
||||
time_now_ms,
|
||||
)
|
||||
# Cast to bytes since postgresql returns a memoryview.
|
||||
json_results.add(bytes(most_recent_result["key_json"]))
|
||||
|
||||
json_results.add(key_result.key_json)
|
||||
|
||||
if miss and query_remote_on_cache_miss:
|
||||
# only bother attempting to fetch keys from servers on our whitelist
|
||||
|
|
|
@ -16,14 +16,13 @@
|
|||
import itertools
|
||||
import json
|
||||
import logging
|
||||
from typing import Any, Dict, Iterable, List, Mapping, Optional, Tuple
|
||||
from typing import Dict, Iterable, Mapping, Optional, Tuple
|
||||
|
||||
from signedjson.key import decode_verify_key_bytes
|
||||
from unpaddedbase64 import decode_base64
|
||||
|
||||
from synapse.storage._base import SQLBaseStore
|
||||
from synapse.storage.database import LoggingTransaction
|
||||
from synapse.storage.keys import FetchKeyResult
|
||||
from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
|
||||
from synapse.storage.keys import FetchKeyResult, FetchKeyResultForRemote
|
||||
from synapse.storage.types import Cursor
|
||||
from synapse.util.caches.descriptors import cached, cachedList
|
||||
from synapse.util.iterutils import batch_iter
|
||||
|
@ -34,7 +33,7 @@ logger = logging.getLogger(__name__)
|
|||
db_binary_type = memoryview
|
||||
|
||||
|
||||
class KeyStore(SQLBaseStore):
|
||||
class KeyStore(CacheInvalidationWorkerStore):
|
||||
"""Persistence for signature verification keys"""
|
||||
|
||||
@cached()
|
||||
|
@ -188,7 +187,12 @@ class KeyStore(SQLBaseStore):
|
|||
# invalidate takes a tuple corresponding to the params of
|
||||
# _get_server_keys_json. _get_server_keys_json only takes one
|
||||
# param, which is itself the 2-tuple (server_name, key_id).
|
||||
self._get_server_keys_json.invalidate(((server_name, key_id),))
|
||||
await self.invalidate_cache_and_stream(
|
||||
"_get_server_keys_json", ((server_name, key_id),)
|
||||
)
|
||||
await self.invalidate_cache_and_stream(
|
||||
"get_server_key_json_for_remote", (server_name, key_id)
|
||||
)
|
||||
|
||||
@cached()
|
||||
def _get_server_keys_json(
|
||||
|
@ -253,36 +257,29 @@ class KeyStore(SQLBaseStore):
|
|||
|
||||
return await self.db_pool.runInteraction("get_server_keys_json", _txn)
|
||||
|
||||
@cached()
def get_server_key_json_for_remote(
    self, server_name: str, key_id: str
) -> Optional[FetchKeyResultForRemote]:
    """Cache slot for a single (server_name, key_id) lookup.

    This method has no implementation of its own: calling it directly
    always raises `NotImplementedError`. It exists so that
    `get_server_keys_json_for_remote` can name it as the
    `cached_method_name` of its `@cachedList` decorator, and so that the
    cache can be invalidated by this name.

    Args:
        server_name: The server the key belongs to.
        key_id: The ID of the key.
    """
    raise NotImplementedError()
|
||||
|
||||
@cachedList(
|
||||
cached_method_name="get_server_key_json_for_remote", list_name="key_ids"
|
||||
)
|
||||
async def get_server_keys_json_for_remote(
|
||||
self, server_keys: Iterable[Tuple[str, Optional[str], Optional[str]]]
|
||||
) -> Dict[Tuple[str, Optional[str], Optional[str]], List[Dict[str, Any]]]:
|
||||
"""Retrieve the key json for a list of server_keys and key ids.
|
||||
If no keys are found for a given server, key_id and source then
|
||||
that server, key_id, and source triplet entry will be an empty list.
|
||||
The JSON is returned as a byte array so that it can be efficiently
|
||||
used in an HTTP response.
|
||||
self, server_name: str, key_ids: Iterable[str]
|
||||
) -> Dict[str, Optional[FetchKeyResultForRemote]]:
|
||||
"""Fetch the cached keys for the given server/key IDs.
|
||||
|
||||
Args:
|
||||
server_keys: List of (server_name, key_id, source) triplets.
|
||||
|
||||
Returns:
|
||||
A mapping from (server_name, key_id, source) triplets to a list of dicts
|
||||
If we have multiple entries for a given key ID, returns the most recent.
|
||||
"""
|
||||
|
||||
def _get_server_keys_json_txn(
|
||||
txn: LoggingTransaction,
|
||||
) -> Dict[Tuple[str, Optional[str], Optional[str]], List[Dict[str, Any]]]:
|
||||
results = {}
|
||||
for server_name, key_id, from_server in server_keys:
|
||||
keyvalues = {"server_name": server_name}
|
||||
if key_id is not None:
|
||||
keyvalues["key_id"] = key_id
|
||||
if from_server is not None:
|
||||
keyvalues["from_server"] = from_server
|
||||
rows = self.db_pool.simple_select_list_txn(
|
||||
txn,
|
||||
"server_keys_json",
|
||||
keyvalues=keyvalues,
|
||||
rows = await self.db_pool.simple_select_many_batch(
|
||||
table="server_keys_json",
|
||||
column="key_id",
|
||||
iterable=key_ids,
|
||||
keyvalues={"server_name": server_name},
|
||||
retcols=(
|
||||
"key_id",
|
||||
"from_server",
|
||||
|
@ -290,10 +287,57 @@ class KeyStore(SQLBaseStore):
|
|||
"ts_valid_until_ms",
|
||||
"key_json",
|
||||
),
|
||||
desc="get_server_keys_json_for_remote",
|
||||
)
|
||||
results[(server_name, key_id, from_server)] = rows
|
||||
return results
|
||||
|
||||
return await self.db_pool.runInteraction(
|
||||
"get_server_keys_json", _get_server_keys_json_txn
|
||||
if not rows:
|
||||
return {}
|
||||
|
||||
# We sort the rows so that the most recently added entry is picked up.
|
||||
rows.sort(key=lambda r: r["ts_added_ms"])
|
||||
|
||||
return {
|
||||
row["key_id"]: FetchKeyResultForRemote(
|
||||
# Cast to bytes since postgresql returns a memoryview.
|
||||
key_json=bytes(row["key_json"]),
|
||||
valid_until_ts=row["ts_valid_until_ms"],
|
||||
added_ts=row["ts_added_ms"],
|
||||
)
|
||||
for row in rows
|
||||
}
|
||||
|
||||
async def get_all_server_keys_json_for_remote(
    self,
    server_name: str,
) -> Dict[str, FetchKeyResultForRemote]:
    """Fetch the cached keys for the given server.

    Unlike `get_server_keys_json_for_remote`, which looks up a specific list
    of key IDs, this selects every row stored for the server.

    If we have multiple entries for a given key ID, returns the most recent.

    Args:
        server_name: The server whose stored keys should be returned.

    Returns:
        A mapping from key ID to the most recently added entry for that key
        ID. Empty if nothing is stored for the server.
    """
    rows = await self.db_pool.simple_select_list(
        table="server_keys_json",
        keyvalues={"server_name": server_name},
        retcols=(
            "key_id",
            "from_server",
            "ts_added_ms",
            "ts_valid_until_ms",
            "key_json",
        ),
        # Label the query with this method's own name rather than reusing
        # the label of `get_server_keys_json_for_remote`, so the two
        # queries can be told apart wherever the description is reported.
        desc="get_all_server_keys_json_for_remote",
    )

    # Sort ascending by insertion time: when a key ID appears more than once,
    # later rows overwrite earlier ones in the dict below, so the most
    # recently added entry is the one that is kept. An empty result simply
    # yields an empty dict.
    rows.sort(key=lambda r: r["ts_added_ms"])

    return {
        row["key_id"]: FetchKeyResultForRemote(
            # Cast to bytes since postgresql returns a memoryview.
            key_json=bytes(row["key_json"]),
            valid_until_ts=row["ts_valid_until_ms"],
            added_ts=row["ts_added_ms"],
        )
        for row in rows
    }
|
||||
|
|
|
@ -25,3 +25,10 @@ logger = logging.getLogger(__name__)
|
|||
class FetchKeyResult:
|
||||
verify_key: VerifyKey # the key itself
|
||||
valid_until_ts: int # how long we can use this key for
|
||||
|
||||
|
||||
@attr.s(slots=True, frozen=True, auto_attribs=True)
class FetchKeyResultForRemote:
    """The JSON of one stored server key plus its validity and insertion times."""

    # The full key JSON.
    key_json: bytes

    # How long we can use this key for, in milliseconds.
    valid_until_ts: int

    # When we added this key, in milliseconds.
    added_ts: int
|
||||
|
|
|
@ -456,24 +456,19 @@ class ServerKeyFetcherTestCase(unittest.HomeserverTestCase):
|
|||
self.assertEqual(k.verify_key.version, "ver1")
|
||||
|
||||
# check that the perspectives store is correctly updated
|
||||
lookup_triplet = (SERVER_NAME, testverifykey_id, None)
|
||||
key_json = self.get_success(
|
||||
self.hs.get_datastores().main.get_server_keys_json_for_remote(
|
||||
[lookup_triplet]
|
||||
SERVER_NAME, [testverifykey_id]
|
||||
)
|
||||
)
|
||||
res_keys = key_json[lookup_triplet]
|
||||
self.assertEqual(len(res_keys), 1)
|
||||
res = res_keys[0]
|
||||
self.assertEqual(res["key_id"], testverifykey_id)
|
||||
self.assertEqual(res["from_server"], SERVER_NAME)
|
||||
self.assertEqual(res["ts_added_ms"], self.reactor.seconds() * 1000)
|
||||
self.assertEqual(res["ts_valid_until_ms"], VALID_UNTIL_TS)
|
||||
res = key_json[testverifykey_id]
|
||||
self.assertIsNotNone(res)
|
||||
assert res is not None
|
||||
self.assertEqual(res.added_ts, self.reactor.seconds() * 1000)
|
||||
self.assertEqual(res.valid_until_ts, VALID_UNTIL_TS)
|
||||
|
||||
# we expect it to be encoded as canonical json *before* it hits the db
|
||||
self.assertEqual(
|
||||
bytes(res["key_json"]), canonicaljson.encode_canonical_json(response)
|
||||
)
|
||||
self.assertEqual(res.key_json, canonicaljson.encode_canonical_json(response))
|
||||
|
||||
# change the server name: the result should be ignored
|
||||
response["server_name"] = "OTHER_SERVER"
|
||||
|
@ -576,23 +571,18 @@ class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase):
|
|||
self.assertEqual(k.verify_key.version, "ver1")
|
||||
|
||||
# check that the perspectives store is correctly updated
|
||||
lookup_triplet = (SERVER_NAME, testverifykey_id, None)
|
||||
key_json = self.get_success(
|
||||
self.hs.get_datastores().main.get_server_keys_json_for_remote(
|
||||
[lookup_triplet]
|
||||
SERVER_NAME, [testverifykey_id]
|
||||
)
|
||||
)
|
||||
res_keys = key_json[lookup_triplet]
|
||||
self.assertEqual(len(res_keys), 1)
|
||||
res = res_keys[0]
|
||||
self.assertEqual(res["key_id"], testverifykey_id)
|
||||
self.assertEqual(res["from_server"], self.mock_perspective_server.server_name)
|
||||
self.assertEqual(res["ts_added_ms"], self.reactor.seconds() * 1000)
|
||||
self.assertEqual(res["ts_valid_until_ms"], VALID_UNTIL_TS)
|
||||
res = key_json[testverifykey_id]
|
||||
self.assertIsNotNone(res)
|
||||
assert res is not None
|
||||
self.assertEqual(res.added_ts, self.reactor.seconds() * 1000)
|
||||
self.assertEqual(res.valid_until_ts, VALID_UNTIL_TS)
|
||||
|
||||
self.assertEqual(
|
||||
bytes(res["key_json"]), canonicaljson.encode_canonical_json(response)
|
||||
)
|
||||
self.assertEqual(res.key_json, canonicaljson.encode_canonical_json(response))
|
||||
|
||||
def test_get_multiple_keys_from_perspectives(self) -> None:
|
||||
"""Check that we can correctly request multiple keys for the same server"""
|
||||
|
@ -699,23 +689,18 @@ class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase):
|
|||
self.assertEqual(k.verify_key.version, "ver1")
|
||||
|
||||
# check that the perspectives store is correctly updated
|
||||
lookup_triplet = (SERVER_NAME, testverifykey_id, None)
|
||||
key_json = self.get_success(
|
||||
self.hs.get_datastores().main.get_server_keys_json_for_remote(
|
||||
[lookup_triplet]
|
||||
SERVER_NAME, [testverifykey_id]
|
||||
)
|
||||
)
|
||||
res_keys = key_json[lookup_triplet]
|
||||
self.assertEqual(len(res_keys), 1)
|
||||
res = res_keys[0]
|
||||
self.assertEqual(res["key_id"], testverifykey_id)
|
||||
self.assertEqual(res["from_server"], self.mock_perspective_server.server_name)
|
||||
self.assertEqual(res["ts_added_ms"], self.reactor.seconds() * 1000)
|
||||
self.assertEqual(res["ts_valid_until_ms"], VALID_UNTIL_TS)
|
||||
res = key_json[testverifykey_id]
|
||||
self.assertIsNotNone(res)
|
||||
assert res is not None
|
||||
self.assertEqual(res.added_ts, self.reactor.seconds() * 1000)
|
||||
self.assertEqual(res.valid_until_ts, VALID_UNTIL_TS)
|
||||
|
||||
self.assertEqual(
|
||||
bytes(res["key_json"]), canonicaljson.encode_canonical_json(response)
|
||||
)
|
||||
self.assertEqual(res.key_json, canonicaljson.encode_canonical_json(response))
|
||||
|
||||
def test_invalid_perspectives_responses(self) -> None:
|
||||
"""Check that invalid responses from the perspectives server are rejected"""
|
||||
|
|
Loading…
Reference in New Issue