Add cache invalidation across workers to module API (#13667)
Signed-off-by: Mathieu Velten <mathieuv@matrix.org>
babolivier/msc3881_device_id_tmp
							parent
							
								
									16e1a9d9a7
								
							
						
					
					
						commit
						6bd8763804
					
				|  | @ -0,0 +1 @@ | |||
| Add cache invalidation across workers to module API. | ||||
|  | @ -29,7 +29,7 @@ class SynapsePlugin(Plugin): | |||
|         self, fullname: str | ||||
|     ) -> Optional[Callable[[MethodSigContext], CallableType]]: | ||||
|         if fullname.startswith( | ||||
|             "synapse.util.caches.descriptors._CachedFunction.__call__" | ||||
|             "synapse.util.caches.descriptors.CachedFunction.__call__" | ||||
|         ) or fullname.startswith( | ||||
|             "synapse.util.caches.descriptors._LruCachedFunction.__call__" | ||||
|         ): | ||||
|  | @ -38,7 +38,7 @@ class SynapsePlugin(Plugin): | |||
| 
 | ||||
| 
 | ||||
| def cached_function_method_signature(ctx: MethodSigContext) -> CallableType: | ||||
|     """Fixes the `_CachedFunction.__call__` signature to be correct. | ||||
|     """Fixes the `CachedFunction.__call__` signature to be correct. | ||||
| 
 | ||||
|     It already has *almost* the correct signature, except: | ||||
| 
 | ||||
|  |  | |||
|  | @ -125,7 +125,7 @@ from synapse.types import ( | |||
| ) | ||||
| from synapse.util import Clock | ||||
| from synapse.util.async_helpers import maybe_awaitable | ||||
| from synapse.util.caches.descriptors import cached | ||||
| from synapse.util.caches.descriptors import CachedFunction, cached | ||||
| from synapse.util.frozenutils import freeze | ||||
| 
 | ||||
| if TYPE_CHECKING: | ||||
|  | @ -836,6 +836,37 @@ class ModuleApi: | |||
|             self._store.db_pool.runInteraction(desc, func, *args, **kwargs)  # type: ignore[arg-type] | ||||
|         ) | ||||
| 
 | ||||
|     def register_cached_function(self, cached_func: CachedFunction) -> None: | ||||
|         """Register a cached function that should be invalidated across workers. | ||||
|         Invalidation local to a worker can be done directly using `cached_func.invalidate`, | ||||
|         however invalidation that needs to go to other workers needs to call `invalidate_cache` | ||||
|         on the module API instead. | ||||
| 
 | ||||
|         Args: | ||||
|             cached_func: The cached function that will be registered to receive invalidation | ||||
|             locally and from other workers. | ||||
|         """ | ||||
|         self._store.register_external_cached_function( | ||||
|             f"{cached_func.__module__}.{cached_func.__name__}", cached_func | ||||
|         ) | ||||
| 
 | ||||
|     async def invalidate_cache( | ||||
|         self, cached_func: CachedFunction, keys: Tuple[Any, ...] | ||||
|     ) -> None: | ||||
|         """Invalidate a cache entry of a cached function across workers. The cached function | ||||
|         needs to be registered on all workers first with `register_cached_function`. | ||||
| 
 | ||||
|         Args: | ||||
|             cached_func: The cached function that needs an invalidation | ||||
|             keys: keys of the entry to invalidate, usually matching the arguments of the | ||||
|             cached function. | ||||
|         """ | ||||
|         cached_func.invalidate(keys) | ||||
|         await self._store.send_invalidation_to_replication( | ||||
|             f"{cached_func.__module__}.{cached_func.__name__}", | ||||
|             keys, | ||||
|         ) | ||||
| 
 | ||||
|     async def complete_sso_login_async( | ||||
|         self, | ||||
|         registered_user_id: str, | ||||
|  |  | |||
|  | @ -15,12 +15,13 @@ | |||
| # limitations under the License. | ||||
| import logging | ||||
| from abc import ABCMeta | ||||
| from typing import TYPE_CHECKING, Any, Collection, Iterable, Optional, Union | ||||
| from typing import TYPE_CHECKING, Any, Collection, Dict, Iterable, Optional, Union | ||||
| 
 | ||||
| from synapse.storage.database import make_in_list_sql_clause  # noqa: F401; noqa: F401 | ||||
| from synapse.storage.database import DatabasePool, LoggingDatabaseConnection | ||||
| from synapse.types import get_domain_from_id | ||||
| from synapse.util import json_decoder | ||||
| from synapse.util.caches.descriptors import CachedFunction | ||||
| 
 | ||||
| if TYPE_CHECKING: | ||||
|     from synapse.server import HomeServer | ||||
|  | @ -47,6 +48,8 @@ class SQLBaseStore(metaclass=ABCMeta): | |||
|         self.database_engine = database.engine | ||||
|         self.db_pool = database | ||||
| 
 | ||||
|         self.external_cached_functions: Dict[str, CachedFunction] = {} | ||||
| 
 | ||||
|     def process_replication_rows( | ||||
|         self, | ||||
|         stream_name: str, | ||||
|  | @ -95,7 +98,7 @@ class SQLBaseStore(metaclass=ABCMeta): | |||
| 
 | ||||
|     def _attempt_to_invalidate_cache( | ||||
|         self, cache_name: str, key: Optional[Collection[Any]] | ||||
|     ) -> None: | ||||
|     ) -> bool: | ||||
|         """Attempts to invalidate the cache of the given name, ignoring if the | ||||
|         cache doesn't exist. Mainly used for invalidating caches on workers, | ||||
|         where they may not have the cache. | ||||
|  | @ -113,9 +116,12 @@ class SQLBaseStore(metaclass=ABCMeta): | |||
|         try: | ||||
|             cache = getattr(self, cache_name) | ||||
|         except AttributeError: | ||||
|             # We probably haven't pulled in the cache in this worker, | ||||
|             # which is fine. | ||||
|             return | ||||
|             # Check if an externally defined module cache has been registered | ||||
|             cache = self.external_cached_functions.get(cache_name) | ||||
|             if not cache: | ||||
|                 # We probably haven't pulled in the cache in this worker, | ||||
|                 # which is fine. | ||||
|                 return False | ||||
| 
 | ||||
|         if key is None: | ||||
|             cache.invalidate_all() | ||||
|  | @ -125,6 +131,13 @@ class SQLBaseStore(metaclass=ABCMeta): | |||
|             invalidate_method = getattr(cache, "invalidate_local", cache.invalidate) | ||||
|             invalidate_method(tuple(key)) | ||||
| 
 | ||||
|         return True | ||||
| 
 | ||||
|     def register_external_cached_function( | ||||
|         self, cache_name: str, func: CachedFunction | ||||
|     ) -> None: | ||||
|         self.external_cached_functions[cache_name] = func | ||||
| 
 | ||||
| 
 | ||||
| def db_to_json(db_content: Union[memoryview, bytes, bytearray, str]) -> Any: | ||||
|     """ | ||||
|  |  | |||
|  | @ -33,7 +33,7 @@ from synapse.storage.database import ( | |||
| ) | ||||
| from synapse.storage.engines import PostgresEngine | ||||
| from synapse.storage.util.id_generators import MultiWriterIdGenerator | ||||
| from synapse.util.caches.descriptors import _CachedFunction | ||||
| from synapse.util.caches.descriptors import CachedFunction | ||||
| from synapse.util.iterutils import batch_iter | ||||
| 
 | ||||
| if TYPE_CHECKING: | ||||
|  | @ -269,9 +269,7 @@ class CacheInvalidationWorkerStore(SQLBaseStore): | |||
|             return | ||||
| 
 | ||||
|         cache_func.invalidate(keys) | ||||
|         await self.db_pool.runInteraction( | ||||
|             "invalidate_cache_and_stream", | ||||
|             self._send_invalidation_to_replication, | ||||
|         await self.send_invalidation_to_replication( | ||||
|             cache_func.__name__, | ||||
|             keys, | ||||
|         ) | ||||
|  | @ -279,7 +277,7 @@ class CacheInvalidationWorkerStore(SQLBaseStore): | |||
|     def _invalidate_cache_and_stream( | ||||
|         self, | ||||
|         txn: LoggingTransaction, | ||||
|         cache_func: _CachedFunction, | ||||
|         cache_func: CachedFunction, | ||||
|         keys: Tuple[Any, ...], | ||||
|     ) -> None: | ||||
|         """Invalidates the cache and adds it to the cache stream so slaves | ||||
|  | @ -293,7 +291,7 @@ class CacheInvalidationWorkerStore(SQLBaseStore): | |||
|         self._send_invalidation_to_replication(txn, cache_func.__name__, keys) | ||||
| 
 | ||||
|     def _invalidate_all_cache_and_stream( | ||||
|         self, txn: LoggingTransaction, cache_func: _CachedFunction | ||||
|         self, txn: LoggingTransaction, cache_func: CachedFunction | ||||
|     ) -> None: | ||||
|         """Invalidates the entire cache and adds it to the cache stream so slaves | ||||
|         will know to invalidate their caches. | ||||
|  | @ -334,6 +332,16 @@ class CacheInvalidationWorkerStore(SQLBaseStore): | |||
|                 txn, CURRENT_STATE_CACHE_NAME, [room_id] | ||||
|             ) | ||||
| 
 | ||||
|     async def send_invalidation_to_replication( | ||||
|         self, cache_name: str, keys: Optional[Collection[Any]] | ||||
|     ) -> None: | ||||
|         await self.db_pool.runInteraction( | ||||
|             "send_invalidation_to_replication", | ||||
|             self._send_invalidation_to_replication, | ||||
|             cache_name, | ||||
|             keys, | ||||
|         ) | ||||
| 
 | ||||
|     def _send_invalidation_to_replication( | ||||
|         self, txn: LoggingTransaction, cache_name: str, keys: Optional[Iterable[Any]] | ||||
|     ) -> None: | ||||
|  |  | |||
|  | @ -53,7 +53,7 @@ CacheKey = Union[Tuple, Any] | |||
| F = TypeVar("F", bound=Callable[..., Any]) | ||||
| 
 | ||||
| 
 | ||||
| class _CachedFunction(Generic[F]): | ||||
| class CachedFunction(Generic[F]): | ||||
|     invalidate: Any = None | ||||
|     invalidate_all: Any = None | ||||
|     prefill: Any = None | ||||
|  | @ -242,7 +242,7 @@ class LruCacheDescriptor(_CacheDescriptorBase): | |||
| 
 | ||||
|             return ret2 | ||||
| 
 | ||||
|         wrapped = cast(_CachedFunction, _wrapped) | ||||
|         wrapped = cast(CachedFunction, _wrapped) | ||||
|         wrapped.cache = cache | ||||
|         obj.__dict__[self.name] = wrapped | ||||
| 
 | ||||
|  | @ -363,7 +363,7 @@ class DeferredCacheDescriptor(_CacheDescriptorBase): | |||
| 
 | ||||
|             return make_deferred_yieldable(ret) | ||||
| 
 | ||||
|         wrapped = cast(_CachedFunction, _wrapped) | ||||
|         wrapped = cast(CachedFunction, _wrapped) | ||||
| 
 | ||||
|         if self.num_args == 1: | ||||
|             assert not self.tree | ||||
|  | @ -572,7 +572,7 @@ def cached( | |||
|     iterable: bool = False, | ||||
|     prune_unread_entries: bool = True, | ||||
|     name: Optional[str] = None, | ||||
| ) -> Callable[[F], _CachedFunction[F]]: | ||||
| ) -> Callable[[F], CachedFunction[F]]: | ||||
|     func = lambda orig: DeferredCacheDescriptor( | ||||
|         orig, | ||||
|         max_entries=max_entries, | ||||
|  | @ -585,7 +585,7 @@ def cached( | |||
|         name=name, | ||||
|     ) | ||||
| 
 | ||||
|     return cast(Callable[[F], _CachedFunction[F]], func) | ||||
|     return cast(Callable[[F], CachedFunction[F]], func) | ||||
| 
 | ||||
| 
 | ||||
| def cachedList( | ||||
|  | @ -594,7 +594,7 @@ def cachedList( | |||
|     list_name: str, | ||||
|     num_args: Optional[int] = None, | ||||
|     name: Optional[str] = None, | ||||
| ) -> Callable[[F], _CachedFunction[F]]: | ||||
| ) -> Callable[[F], CachedFunction[F]]: | ||||
|     """Creates a descriptor that wraps a function in a `DeferredCacheListDescriptor`. | ||||
| 
 | ||||
|     Used to do batch lookups for an already created cache. One of the arguments | ||||
|  | @ -631,7 +631,7 @@ def cachedList( | |||
|         name=name, | ||||
|     ) | ||||
| 
 | ||||
|     return cast(Callable[[F], _CachedFunction[F]], func) | ||||
|     return cast(Callable[[F], CachedFunction[F]], func) | ||||
| 
 | ||||
| 
 | ||||
| def _get_cache_key_builder( | ||||
|  |  | |||
|  | @ -0,0 +1,79 @@ | |||
| # Copyright 2022 The Matrix.org Foundation C.I.C. | ||||
| # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | ||||
| # You may obtain a copy of the License at | ||||
| # | ||||
| #     http://www.apache.org/licenses/LICENSE-2.0 | ||||
| # | ||||
| # Unless required by applicable law or agreed to in writing, software | ||||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| # See the License for the specific language governing permissions and | ||||
| # limitations under the License. | ||||
| import logging | ||||
| 
 | ||||
| import synapse | ||||
| from synapse.module_api import cached | ||||
| 
 | ||||
| from tests.replication._base import BaseMultiWorkerStreamTestCase | ||||
| 
 | ||||
| logger = logging.getLogger(__name__) | ||||
| 
 | ||||
| FIRST_VALUE = "one" | ||||
| SECOND_VALUE = "two" | ||||
| 
 | ||||
| KEY = "mykey" | ||||
| 
 | ||||
| 
 | ||||
| class TestCache: | ||||
|     current_value = FIRST_VALUE | ||||
| 
 | ||||
|     @cached() | ||||
|     async def cached_function(self, user_id: str) -> str: | ||||
|         return self.current_value | ||||
| 
 | ||||
| 
 | ||||
| class ModuleCacheInvalidationTestCase(BaseMultiWorkerStreamTestCase): | ||||
|     servlets = [ | ||||
|         synapse.rest.admin.register_servlets, | ||||
|     ] | ||||
| 
 | ||||
|     def test_module_cache_full_invalidation(self): | ||||
|         main_cache = TestCache() | ||||
|         self.hs.get_module_api().register_cached_function(main_cache.cached_function) | ||||
| 
 | ||||
|         worker_hs = self.make_worker_hs("synapse.app.generic_worker") | ||||
| 
 | ||||
|         worker_cache = TestCache() | ||||
|         worker_hs.get_module_api().register_cached_function( | ||||
|             worker_cache.cached_function | ||||
|         ) | ||||
| 
 | ||||
|         self.assertEqual(FIRST_VALUE, self.get_success(main_cache.cached_function(KEY))) | ||||
|         self.assertEqual( | ||||
|             FIRST_VALUE, self.get_success(worker_cache.cached_function(KEY)) | ||||
|         ) | ||||
| 
 | ||||
|         main_cache.current_value = SECOND_VALUE | ||||
|         worker_cache.current_value = SECOND_VALUE | ||||
|         # No invalidation yet, should return the cached value on both the main process and the worker | ||||
|         self.assertEqual(FIRST_VALUE, self.get_success(main_cache.cached_function(KEY))) | ||||
|         self.assertEqual( | ||||
|             FIRST_VALUE, self.get_success(worker_cache.cached_function(KEY)) | ||||
|         ) | ||||
| 
 | ||||
|         # Full invalidation on the main process, should be replicated on the worker, which | ||||
|         # should return the updated value too | ||||
|         self.get_success( | ||||
|             self.hs.get_module_api().invalidate_cache( | ||||
|                 main_cache.cached_function, (KEY,) | ||||
|             ) | ||||
|         ) | ||||
| 
 | ||||
|         self.assertEqual( | ||||
|             SECOND_VALUE, self.get_success(main_cache.cached_function(KEY)) | ||||
|         ) | ||||
|         self.assertEqual( | ||||
|             SECOND_VALUE, self.get_success(worker_cache.cached_function(KEY)) | ||||
|         ) | ||||
		Loading…
	
		Reference in New Issue
	
	 Mathieu Velten
						Mathieu Velten