Bulk-invalidate e2e cached queries after claiming keys (#16613)

Co-authored-by: Patrick Cloke <patrickc@matrix.org>
pull/16616/head
David Robertson 2023-11-09 15:57:09 +00:00 committed by GitHub
parent f6aa047aa2
commit 91587d4cf9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 249 additions and 28 deletions

View File

@ -0,0 +1 @@
Improve the performance of claiming encryption keys in multi-worker deployments.

View File

@ -1116,7 +1116,7 @@ class DatabasePool:
def simple_insert_many_txn( def simple_insert_many_txn(
txn: LoggingTransaction, txn: LoggingTransaction,
table: str, table: str,
keys: Collection[str], keys: Sequence[str],
values: Collection[Iterable[Any]], values: Collection[Iterable[Any]],
) -> None: ) -> None:
"""Executes an INSERT query on the named table. """Executes an INSERT query on the named table.

View File

@ -483,6 +483,30 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
txn.call_after(cache_func.invalidate, keys) txn.call_after(cache_func.invalidate, keys)
self._send_invalidation_to_replication(txn, cache_func.__name__, keys) self._send_invalidation_to_replication(txn, cache_func.__name__, keys)
def _invalidate_cache_and_stream_bulk(
self,
txn: LoggingTransaction,
cache_func: CachedFunction,
key_tuples: Collection[Tuple[Any, ...]],
) -> None:
"""A bulk version of _invalidate_cache_and_stream.
Locally invalidate every key-tuple in `key_tuples`, then emit invalidations
for each key-tuple over replication.
This implementation is more efficient than a loop which repeatedly calls the
non-bulk version.
"""
if not key_tuples:
return
for keys in key_tuples:
txn.call_after(cache_func.invalidate, keys)
self._send_invalidation_to_replication_bulk(
txn, cache_func.__name__, key_tuples
)
def _invalidate_all_cache_and_stream( def _invalidate_all_cache_and_stream(
self, txn: LoggingTransaction, cache_func: CachedFunction self, txn: LoggingTransaction, cache_func: CachedFunction
) -> None: ) -> None:
@ -564,10 +588,6 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
if isinstance(self.database_engine, PostgresEngine): if isinstance(self.database_engine, PostgresEngine):
assert self._cache_id_gen is not None assert self._cache_id_gen is not None
# get_next() returns a context manager which is designed to wrap
# the transaction. However, we want to only get an ID when we want
# to use it, here, so we need to call __enter__ manually, and have
# __exit__ called after the transaction finishes.
stream_id = self._cache_id_gen.get_next_txn(txn) stream_id = self._cache_id_gen.get_next_txn(txn)
txn.call_after(self.hs.get_notifier().on_new_replication_data) txn.call_after(self.hs.get_notifier().on_new_replication_data)
@ -586,6 +606,53 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
}, },
) )
def _send_invalidation_to_replication_bulk(
self,
txn: LoggingTransaction,
cache_name: str,
key_tuples: Collection[Tuple[Any, ...]],
) -> None:
"""Announce the invalidation of multiple (but not all) cache entries.
This is more efficient than repeated calls to the non-bulk version. It should
NOT be used to invalidating the entire cache: use
`_send_invalidation_to_replication` with keys=None.
Note that this does *not* invalidate the cache locally.
Args:
txn
cache_name
key_tuples: Key-tuples to invalidate. Assumed to be non-empty.
"""
if isinstance(self.database_engine, PostgresEngine):
assert self._cache_id_gen is not None
stream_ids = self._cache_id_gen.get_next_mult_txn(txn, len(key_tuples))
ts = self._clock.time_msec()
txn.call_after(self.hs.get_notifier().on_new_replication_data)
self.db_pool.simple_insert_many_txn(
txn,
table="cache_invalidation_stream_by_instance",
keys=(
"stream_id",
"instance_name",
"cache_func",
"keys",
"invalidation_ts",
),
values=[
# We convert key_tuples to a list here because psycopg2 serialises
# lists as pq arrrays, but serialises tuples as "composite types".
# (We need an array because the `keys` column has type `[]text`.)
# See:
# https://www.psycopg.org/docs/usage.html#adapt-list
# https://www.psycopg.org/docs/usage.html#adapt-tuple
(stream_id, self._instance_name, cache_name, list(key_tuple), ts)
for stream_id, key_tuple in zip(stream_ids, key_tuples)
],
)
def get_cache_stream_token_for_writer(self, instance_name: str) -> int: def get_cache_stream_token_for_writer(self, instance_name: str) -> int:
if self._cache_id_gen: if self._cache_id_gen:
return self._cache_id_gen.get_current_token_for_writer(instance_name) return self._cache_id_gen.get_current_token_for_writer(instance_name)

View File

@ -1237,13 +1237,11 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
for user_id, device_id, algorithm, key_id, key_json in claimed_keys: for user_id, device_id, algorithm, key_id, key_json in claimed_keys:
device_results = results.setdefault(user_id, {}).setdefault(device_id, {}) device_results = results.setdefault(user_id, {}).setdefault(device_id, {})
device_results[f"{algorithm}:{key_id}"] = json_decoder.decode(key_json) device_results[f"{algorithm}:{key_id}"] = json_decoder.decode(key_json)
if (user_id, device_id) in seen_user_device:
continue
seen_user_device.add((user_id, device_id)) seen_user_device.add((user_id, device_id))
self._invalidate_cache_and_stream(
txn, self.get_e2e_unused_fallback_key_types, (user_id, device_id) self._invalidate_cache_and_stream_bulk(
) txn, self.get_e2e_unused_fallback_key_types, seen_user_device
)
return results return results
@ -1376,14 +1374,14 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
List[Tuple[str, str, str, str, str]], txn.execute_values(sql, query_list) List[Tuple[str, str, str, str, str]], txn.execute_values(sql, query_list)
) )
seen_user_device: Set[Tuple[str, str]] = set() seen_user_device = {
for user_id, device_id, _, _, _ in otk_rows: (user_id, device_id) for user_id, device_id, _, _, _ in otk_rows
if (user_id, device_id) in seen_user_device: }
continue self._invalidate_cache_and_stream_bulk(
seen_user_device.add((user_id, device_id)) txn,
self._invalidate_cache_and_stream( self.count_e2e_one_time_keys,
txn, self.count_e2e_one_time_keys, (user_id, device_id) seen_user_device,
) )
return otk_rows return otk_rows

View File

@ -650,8 +650,8 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
next_id = self._load_next_id_txn(txn) next_id = self._load_next_id_txn(txn)
txn.call_after(self._mark_id_as_finished, next_id) txn.call_after(self._mark_ids_as_finished, [next_id])
txn.call_on_exception(self._mark_id_as_finished, next_id) txn.call_on_exception(self._mark_ids_as_finished, [next_id])
txn.call_after(self._notifier.notify_replication) txn.call_after(self._notifier.notify_replication)
# Update the `stream_positions` table with newly updated stream # Update the `stream_positions` table with newly updated stream
@ -671,14 +671,50 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
return self._return_factor * next_id return self._return_factor * next_id
def _mark_id_as_finished(self, next_id: int) -> None: def get_next_mult_txn(self, txn: LoggingTransaction, n: int) -> List[int]:
"""The ID has finished being processed so we should advance the """
Usage:
stream_id = stream_id_gen.get_next_txn(txn)
# ... persist event ...
"""
# If we have a list of instances that are allowed to write to this
# stream, make sure we're in it.
if self._writers and self._instance_name not in self._writers:
raise Exception("Tried to allocate stream ID on non-writer")
next_ids = self._load_next_mult_id_txn(txn, n)
txn.call_after(self._mark_ids_as_finished, next_ids)
txn.call_on_exception(self._mark_ids_as_finished, next_ids)
txn.call_after(self._notifier.notify_replication)
# Update the `stream_positions` table with newly updated stream
# ID (unless self._writers is not set in which case we don't
# bother, as nothing will read it).
#
# We only do this on the success path so that the persisted current
# position points to a persisted row with the correct instance name.
if self._writers:
txn.call_after(
run_as_background_process,
"MultiWriterIdGenerator._update_table",
self._db.runInteraction,
"MultiWriterIdGenerator._update_table",
self._update_stream_positions_table_txn,
)
return [self._return_factor * next_id for next_id in next_ids]
def _mark_ids_as_finished(self, next_ids: List[int]) -> None:
"""These IDs have finished being processed so we should advance the
current position if possible. current position if possible.
""" """
with self._lock: with self._lock:
self._unfinished_ids.discard(next_id) self._unfinished_ids.difference_update(next_ids)
self._finished_ids.add(next_id) self._finished_ids.update(next_ids)
new_cur: Optional[int] = None new_cur: Optional[int] = None
@ -727,7 +763,10 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
curr, new_cur, self._max_position_of_local_instance curr, new_cur, self._max_position_of_local_instance
) )
self._add_persisted_position(next_id) # TODO Can we call this for just the last position or somehow batch
# _add_persisted_position.
for next_id in next_ids:
self._add_persisted_position(next_id)
def get_current_token(self) -> int: def get_current_token(self) -> int:
return self.get_persisted_upto_position() return self.get_persisted_upto_position()
@ -933,8 +972,7 @@ class _MultiWriterCtxManager:
exc: Optional[BaseException], exc: Optional[BaseException],
tb: Optional[TracebackType], tb: Optional[TracebackType],
) -> bool: ) -> bool:
for i in self.stream_ids: self.id_gen._mark_ids_as_finished(self.stream_ids)
self.id_gen._mark_id_as_finished(i)
self.notifier.notify_replication() self.notifier.notify_replication()

View File

@ -0,0 +1,117 @@
# Copyright 2023 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from unittest.mock import Mock, call
from synapse.storage.database import LoggingTransaction
from tests.replication._base import BaseMultiWorkerStreamTestCase
from tests.unittest import HomeserverTestCase
class CacheInvalidationTestCase(HomeserverTestCase):
def setUp(self) -> None:
super().setUp()
self.store = self.hs.get_datastores().main
def test_bulk_invalidation(self) -> None:
master_invalidate = Mock()
self.store._get_cached_user_device.invalidate = master_invalidate
keys_to_invalidate = [
("a", "b"),
("c", "d"),
("e", "f"),
("g", "h"),
]
def test_txn(txn: LoggingTransaction) -> None:
self.store._invalidate_cache_and_stream_bulk(
txn,
# This is an arbitrarily chosen cached store function. It was chosen
# because it takes more than one argument. We'll use this later to
# check that the invalidation was actioned over replication.
cache_func=self.store._get_cached_user_device,
key_tuples=keys_to_invalidate,
)
self.get_success(
self.store.db_pool.runInteraction(
"test_invalidate_cache_and_stream_bulk", test_txn
)
)
master_invalidate.assert_has_calls(
[call(key_list) for key_list in keys_to_invalidate],
any_order=True,
)
class CacheInvalidationOverReplicationTestCase(BaseMultiWorkerStreamTestCase):
def setUp(self) -> None:
super().setUp()
self.store = self.hs.get_datastores().main
def test_bulk_invalidation_replicates(self) -> None:
"""Like test_bulk_invalidation, but also checks the invalidations replicate."""
master_invalidate = Mock()
worker_invalidate = Mock()
self.store._get_cached_user_device.invalidate = master_invalidate
worker = self.make_worker_hs("synapse.app.generic_worker")
worker_ds = worker.get_datastores().main
worker_ds._get_cached_user_device.invalidate = worker_invalidate
keys_to_invalidate = [
("a", "b"),
("c", "d"),
("e", "f"),
("g", "h"),
]
def test_txn(txn: LoggingTransaction) -> None:
self.store._invalidate_cache_and_stream_bulk(
txn,
# This is an arbitrarily chosen cached store function. It was chosen
# because it takes more than one argument. We'll use this later to
# check that the invalidation was actioned over replication.
cache_func=self.store._get_cached_user_device,
key_tuples=keys_to_invalidate,
)
assert self.store._cache_id_gen is not None
initial_token = self.store._cache_id_gen.get_current_token()
self.get_success(
self.database_pool.runInteraction(
"test_invalidate_cache_and_stream_bulk", test_txn
)
)
second_token = self.store._cache_id_gen.get_current_token()
self.assertGreaterEqual(second_token, initial_token + len(keys_to_invalidate))
self.get_success(
worker.get_replication_data_handler().wait_for_stream_position(
"master", "caches", second_token
)
)
master_invalidate.assert_has_calls(
[call(key_list) for key_list in keys_to_invalidate],
any_order=True,
)
worker_invalidate.assert_has_calls(
[call(key_list) for key_list in keys_to_invalidate],
any_order=True,
)