From c486fa5fd9082643e40a55ffa59d902aa6db4c2b Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 16 Mar 2022 10:37:04 -0400 Subject: [PATCH] Add some missing type hints to cache datastore. (#12216) --- changelog.d/12216.misc | 1 + synapse/storage/databases/main/cache.py | 57 ++++++++++++++++--------- 2 files changed, 37 insertions(+), 21 deletions(-) create mode 100644 changelog.d/12216.misc diff --git a/changelog.d/12216.misc b/changelog.d/12216.misc new file mode 100644 index 0000000000..dc398ac1e0 --- /dev/null +++ b/changelog.d/12216.misc @@ -0,0 +1 @@ +Add missing type hints for cache storage. diff --git a/synapse/storage/databases/main/cache.py b/synapse/storage/databases/main/cache.py index d6a2df1afe..2d7511d613 100644 --- a/synapse/storage/databases/main/cache.py +++ b/synapse/storage/databases/main/cache.py @@ -23,6 +23,7 @@ from synapse.replication.tcp.streams.events import ( EventsStream, EventsStreamCurrentStateRow, EventsStreamEventRow, + EventsStreamRow, ) from synapse.storage._base import SQLBaseStore from synapse.storage.database import ( @@ -31,6 +32,7 @@ from synapse.storage.database import ( LoggingTransaction, ) from synapse.storage.engines import PostgresEngine +from synapse.util.caches.descriptors import _CachedFunction from synapse.util.iterutils import batch_iter if TYPE_CHECKING: @@ -82,7 +84,9 @@ class CacheInvalidationWorkerStore(SQLBaseStore): if last_id == current_id: return [], current_id, False - def get_all_updated_caches_txn(txn): + def get_all_updated_caches_txn( + txn: LoggingTransaction, + ) -> Tuple[List[Tuple[int, tuple]], int, bool]: # We purposefully don't bound by the current token, as we want to # send across cache invalidations as quickly as possible. Cache # invalidations are idempotent, so duplicates are fine. @@ -107,7 +111,9 @@ class CacheInvalidationWorkerStore(SQLBaseStore): "get_all_updated_caches", get_all_updated_caches_txn ) - def process_replication_rows(self, stream_name, instance_name, token, rows): + def process_replication_rows( + self, stream_name: str, instance_name: str, token: int, rows: Iterable[Any] + ) -> None: if stream_name == EventsStream.NAME: for row in rows: self._process_event_stream_row(token, row) @@ -142,10 +148,11 @@ class CacheInvalidationWorkerStore(SQLBaseStore): super().process_replication_rows(stream_name, instance_name, token, rows) - def _process_event_stream_row(self, token, row): + def _process_event_stream_row(self, token: int, row: EventsStreamRow) -> None: data = row.data if row.type == EventsStreamEventRow.TypeId: + assert isinstance(data, EventsStreamEventRow) self._invalidate_caches_for_event( token, data.event_id, @@ -157,9 +164,8 @@ class CacheInvalidationWorkerStore(SQLBaseStore): backfilled=False, ) elif row.type == EventsStreamCurrentStateRow.TypeId: - self._curr_state_delta_stream_cache.entity_has_changed( - row.data.room_id, token - ) + assert isinstance(data, EventsStreamCurrentStateRow) + self._curr_state_delta_stream_cache.entity_has_changed(data.room_id, token) if data.type == EventTypes.Member: self.get_rooms_for_user_with_stream_ordering.invalidate( @@ -170,15 +176,15 @@ class CacheInvalidationWorkerStore(SQLBaseStore): def _invalidate_caches_for_event( self, - stream_ordering, - event_id, - room_id, - etype, - state_key, - redacts, - relates_to, - backfilled, - ): + stream_ordering: int, + event_id: str, + room_id: str, + etype: str, + state_key: Optional[str], + redacts: Optional[str], + relates_to: Optional[str], + backfilled: bool, + ) -> None: self._invalidate_get_event_cache(event_id) self.have_seen_event.invalidate((room_id, event_id)) @@ -207,7 +213,9 @@ class CacheInvalidationWorkerStore(SQLBaseStore): self.get_thread_summary.invalidate((relates_to,)) self.get_thread_participated.invalidate((relates_to,)) - async def invalidate_cache_and_stream(self, cache_name: str, keys: Tuple[Any, ...]): + async def invalidate_cache_and_stream( + self, cache_name: str, keys: Tuple[Any, ...] + ) -> None: """Invalidates the cache and adds it to the cache stream so slaves will know to invalidate their caches. @@ -227,7 +235,12 @@ class CacheInvalidationWorkerStore(SQLBaseStore): keys, ) - def _invalidate_cache_and_stream(self, txn, cache_func, keys): + def _invalidate_cache_and_stream( + self, + txn: LoggingTransaction, + cache_func: _CachedFunction, + keys: Tuple[Any, ...], + ) -> None: """Invalidates the cache and adds it to the cache stream so slaves will know to invalidate their caches. @@ -238,7 +251,9 @@ class CacheInvalidationWorkerStore(SQLBaseStore): txn.call_after(cache_func.invalidate, keys) self._send_invalidation_to_replication(txn, cache_func.__name__, keys) - def _invalidate_all_cache_and_stream(self, txn, cache_func): + def _invalidate_all_cache_and_stream( + self, txn: LoggingTransaction, cache_func: _CachedFunction + ) -> None: """Invalidates the entire cache and adds it to the cache stream so slaves will know to invalidate their caches. """ @@ -279,8 +294,8 @@ class CacheInvalidationWorkerStore(SQLBaseStore): ) def _send_invalidation_to_replication( - self, txn, cache_name: str, keys: Optional[Iterable[Any]] - ): + self, txn: LoggingTransaction, cache_name: str, keys: Optional[Iterable[Any]] + ) -> None: """Notifies replication that given cache has been invalidated. Note that this does *not* invalidate the cache locally. @@ -315,7 +330,7 @@ class CacheInvalidationWorkerStore(SQLBaseStore): "instance_name": self._instance_name, "cache_func": cache_name, "keys": keys, - "invalidation_ts": self.clock.time_msec(), + "invalidation_ts": self._clock.time_msec(), }, )