diff --git a/changelog.d/11984.misc b/changelog.d/11984.misc new file mode 100644 index 0000000000..8e405b9226 --- /dev/null +++ b/changelog.d/11984.misc @@ -0,0 +1 @@ +Add missing type hints to storage classes. \ No newline at end of file diff --git a/mypy.ini b/mypy.ini index 63848d664c..610660b9b7 100644 --- a/mypy.ini +++ b/mypy.ini @@ -31,14 +31,11 @@ exclude = (?x) |synapse/storage/databases/main/group_server.py |synapse/storage/databases/main/metrics.py |synapse/storage/databases/main/monthly_active_users.py - |synapse/storage/databases/main/presence.py - |synapse/storage/databases/main/purge_events.py |synapse/storage/databases/main/push_rule.py |synapse/storage/databases/main/receipts.py |synapse/storage/databases/main/roommember.py |synapse/storage/databases/main/search.py |synapse/storage/databases/main/state.py - |synapse/storage/databases/main/user_directory.py |synapse/storage/schema/ |tests/api/test_auth.py diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py index 067c43ae47..b223b72623 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py @@ -204,25 +204,27 @@ class BasePresenceHandler(abc.ABC): Returns: dict: `user_id` -> `UserPresenceState` """ - states = { - user_id: self.user_to_current_state.get(user_id, None) - for user_id in user_ids - } + states = {} + missing = [] + for user_id in user_ids: + state = self.user_to_current_state.get(user_id, None) + if state: + states[user_id] = state + else: + missing.append(user_id) - missing = [user_id for user_id, state in states.items() if not state] if missing: # There are things not in our in memory cache. Lets pull them out of # the database. res = await self.store.get_presence_for_users(missing) states.update(res) - missing = [user_id for user_id, state in states.items() if not state] - if missing: - new = { - user_id: UserPresenceState.default(user_id) for user_id in missing - } - states.update(new) - self.user_to_current_state.update(new) + for user_id in missing: + # if user has no state in database, create the state + if not res.get(user_id, None): + new_state = UserPresenceState.default(user_id) + states[user_id] = new_state + self.user_to_current_state[user_id] = new_state return states diff --git a/synapse/storage/databases/main/presence.py b/synapse/storage/databases/main/presence.py index 4f05811a77..d3c4611686 100644 --- a/synapse/storage/databases/main/presence.py +++ b/synapse/storage/databases/main/presence.py @@ -12,15 +12,23 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import TYPE_CHECKING, Dict, Iterable, List, Tuple +from typing import TYPE_CHECKING, Dict, Iterable, List, Tuple, cast from synapse.api.presence import PresenceState, UserPresenceState from synapse.replication.tcp.streams import PresenceStream from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause -from synapse.storage.database import DatabasePool, LoggingDatabaseConnection +from synapse.storage.database import ( + DatabasePool, + LoggingDatabaseConnection, + LoggingTransaction, +) from synapse.storage.engines import PostgresEngine from synapse.storage.types import Connection -from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator +from synapse.storage.util.id_generators import ( + AbstractStreamIdGenerator, + MultiWriterIdGenerator, + StreamIdGenerator, +) from synapse.util.caches.descriptors import cached, cachedList from synapse.util.caches.stream_change_cache import StreamChangeCache from synapse.util.iterutils import batch_iter @@ -35,7 +43,7 @@ class PresenceBackgroundUpdateStore(SQLBaseStore): database: DatabasePool, db_conn: LoggingDatabaseConnection, hs: "HomeServer", - ): + ) -> None: super().__init__(database, db_conn, hs) # Used by `PresenceStore._get_active_presence()` @@ -54,11 +62,14 @@ class PresenceStore(PresenceBackgroundUpdateStore): database: DatabasePool, db_conn: LoggingDatabaseConnection, hs: "HomeServer", - ): + ) -> None: super().__init__(database, db_conn, hs) + self._instance_name = hs.get_instance_name() + self._presence_id_gen: AbstractStreamIdGenerator + self._can_persist_presence = ( - hs.get_instance_name() in hs.config.worker.writers.presence + self._instance_name in hs.config.worker.writers.presence ) if isinstance(database.engine, PostgresEngine): @@ -109,7 +120,9 @@ class PresenceStore(PresenceBackgroundUpdateStore): return stream_orderings[-1], self._presence_id_gen.get_current_token() - def _update_presence_txn(self, txn, stream_orderings, presence_states): + def _update_presence_txn( + self, txn: LoggingTransaction, stream_orderings, presence_states + ) -> None: for stream_id, state in zip(stream_orderings, presence_states): txn.call_after( self.presence_stream_cache.entity_has_changed, state.user_id, stream_id @@ -183,19 +196,23 @@ class PresenceStore(PresenceBackgroundUpdateStore): if last_id == current_id: return [], current_id, False - def get_all_presence_updates_txn(txn): + def get_all_presence_updates_txn( + txn: LoggingTransaction, + ) -> Tuple[List[Tuple[int, list]], int, bool]: sql = """ SELECT stream_id, user_id, state, last_active_ts, last_federation_update_ts, last_user_sync_ts, - status_msg, - currently_active + status_msg, currently_active FROM presence_stream WHERE ? < stream_id AND stream_id <= ? ORDER BY stream_id ASC LIMIT ? """ txn.execute(sql, (last_id, current_id, limit)) - updates = [(row[0], row[1:]) for row in txn] + updates = cast( + List[Tuple[int, list]], + [(row[0], row[1:]) for row in txn], + ) upper_bound = current_id limited = False @@ -210,7 +227,7 @@ class PresenceStore(PresenceBackgroundUpdateStore): ) @cached() - def _get_presence_for_user(self, user_id): + def _get_presence_for_user(self, user_id: str) -> None: raise NotImplementedError() @cachedList( @@ -218,7 +235,9 @@ class PresenceStore(PresenceBackgroundUpdateStore): list_name="user_ids", num_args=1, ) - async def get_presence_for_users(self, user_ids): + async def get_presence_for_users( + self, user_ids: Iterable[str] + ) -> Dict[str, UserPresenceState]: rows = await self.db_pool.simple_select_many_batch( table="presence_stream", column="user_id", @@ -257,7 +276,9 @@ class PresenceStore(PresenceBackgroundUpdateStore): True if the user should have full presence sent to them, False otherwise. """ - def _should_user_receive_full_presence_with_token_txn(txn): + def _should_user_receive_full_presence_with_token_txn( + txn: LoggingTransaction, + ) -> bool: sql = """ SELECT 1 FROM users_to_send_full_presence_to WHERE user_id = ? @@ -271,7 +292,7 @@ class PresenceStore(PresenceBackgroundUpdateStore): _should_user_receive_full_presence_with_token_txn, ) - async def add_users_to_send_full_presence_to(self, user_ids: Iterable[str]): + async def add_users_to_send_full_presence_to(self, user_ids: Iterable[str]) -> None: """Adds to the list of users who should receive a full snapshot of presence upon their next sync. @@ -353,10 +374,10 @@ class PresenceStore(PresenceBackgroundUpdateStore): return users_to_state - def get_current_presence_token(self): + def get_current_presence_token(self) -> int: return self._presence_id_gen.get_current_token() - def _get_active_presence(self, db_conn: Connection): + def _get_active_presence(self, db_conn: Connection) -> List[UserPresenceState]: """Fetch non-offline presence from the database so that we can register the appropriate time outs. """ @@ -379,12 +400,12 @@ class PresenceStore(PresenceBackgroundUpdateStore): return [UserPresenceState(**row) for row in rows] - def take_presence_startup_info(self): + def take_presence_startup_info(self) -> List[UserPresenceState]: active_on_startup = self._presence_on_startup - self._presence_on_startup = None + self._presence_on_startup = [] return active_on_startup - def process_replication_rows(self, stream_name, instance_name, token, rows): + def process_replication_rows(self, stream_name, instance_name, token, rows) -> None: if stream_name == PresenceStream.NAME: self._presence_id_gen.advance(instance_name, token) for row in rows: diff --git a/synapse/storage/databases/main/purge_events.py b/synapse/storage/databases/main/purge_events.py index e87a8fb85d..2e3818e432 100644 --- a/synapse/storage/databases/main/purge_events.py +++ b/synapse/storage/databases/main/purge_events.py @@ -13,9 +13,10 @@ # limitations under the License. import logging -from typing import Any, List, Set, Tuple +from typing import Any, List, Set, Tuple, cast from synapse.api.errors import SynapseError +from synapse.storage.database import LoggingTransaction from synapse.storage.databases.main import CacheInvalidationWorkerStore from synapse.storage.databases.main.state import StateGroupWorkerStore from synapse.types import RoomStreamToken @@ -55,7 +56,11 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore): ) def _purge_history_txn( - self, txn, room_id: str, token: RoomStreamToken, delete_local_events: bool + self, + txn: LoggingTransaction, + room_id: str, + token: RoomStreamToken, + delete_local_events: bool, ) -> Set[int]: # Tables that should be pruned: # event_auth @@ -273,7 +278,7 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore): """, (room_id,), ) - (min_depth,) = txn.fetchone() + (min_depth,) = cast(Tuple[int], txn.fetchone()) logger.info("[purge] updating room_depth to %d", min_depth) @@ -318,7 +323,7 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore): "purge_room", self._purge_room_txn, room_id ) - def _purge_room_txn(self, txn, room_id: str) -> List[int]: + def _purge_room_txn(self, txn: LoggingTransaction, room_id: str) -> List[int]: # First we fetch all the state groups that should be deleted, before # we delete that information. txn.execute( diff --git a/synapse/storage/databases/main/user_directory.py b/synapse/storage/databases/main/user_directory.py index f7c778bdf2..e7fddd2426 100644 --- a/synapse/storage/databases/main/user_directory.py +++ b/synapse/storage/databases/main/user_directory.py @@ -58,7 +58,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore): database: DatabasePool, db_conn: LoggingDatabaseConnection, hs: "HomeServer", - ): + ) -> None: super().__init__(database, db_conn, hs) self.server_name = hs.hostname @@ -234,10 +234,10 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore): processed_event_count = 0 for room_id, event_count in rooms_to_work_on: - is_in_room = await self.is_host_joined(room_id, self.server_name) + is_in_room = await self.is_host_joined(room_id, self.server_name) # type: ignore[attr-defined] if is_in_room: - users_with_profile = await self.get_users_in_room_with_profiles(room_id) + users_with_profile = await self.get_users_in_room_with_profiles(room_id) # type: ignore[attr-defined] # Throw away users excluded from the directory. users_with_profile = { user_id: profile @@ -368,7 +368,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore): for user_id in users_to_work_on: if await self.should_include_local_user_in_dir(user_id): - profile = await self.get_profileinfo(get_localpart_from_id(user_id)) + profile = await self.get_profileinfo(get_localpart_from_id(user_id)) # type: ignore[attr-defined] await self.update_profile_in_user_dir( user_id, profile.display_name, profile.avatar_url ) @@ -397,7 +397,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore): # technically it could be DM-able. In the future, this could potentially # be configurable per-appservice whether the appservice sender can be # contacted. - if self.get_app_service_by_user_id(user) is not None: + if self.get_app_service_by_user_id(user) is not None: # type: ignore[attr-defined] return False # We're opting to exclude appservice users (anyone matching the user @@ -405,17 +405,17 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore): # they could be DM-able. In the future, this could potentially # be configurable per-appservice whether the appservice users can be # contacted. - if self.get_if_app_services_interested_in_user(user): + if self.get_if_app_services_interested_in_user(user): # type: ignore[attr-defined] # TODO we might want to make this configurable for each app service return False # Support users are for diagnostics and should not appear in the user directory. - if await self.is_support_user(user): + if await self.is_support_user(user): # type: ignore[attr-defined] return False # Deactivated users aren't contactable, so should not appear in the user directory. try: - if await self.get_user_deactivated_status(user): + if await self.get_user_deactivated_status(user): # type: ignore[attr-defined] return False except StoreError: # No such user in the users table. No need to do this when calling @@ -433,20 +433,20 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore): (EventTypes.RoomHistoryVisibility, ""), ) - current_state_ids = await self.get_filtered_current_state_ids( + current_state_ids = await self.get_filtered_current_state_ids( # type: ignore[attr-defined] room_id, StateFilter.from_types(types_to_filter) ) join_rules_id = current_state_ids.get((EventTypes.JoinRules, "")) if join_rules_id: - join_rule_ev = await self.get_event(join_rules_id, allow_none=True) + join_rule_ev = await self.get_event(join_rules_id, allow_none=True) # type: ignore[attr-defined] if join_rule_ev: if join_rule_ev.content.get("join_rule") == JoinRules.PUBLIC: return True hist_vis_id = current_state_ids.get((EventTypes.RoomHistoryVisibility, "")) if hist_vis_id: - hist_vis_ev = await self.get_event(hist_vis_id, allow_none=True) + hist_vis_ev = await self.get_event(hist_vis_id, allow_none=True) # type: ignore[attr-defined] if hist_vis_ev: if ( hist_vis_ev.content.get("history_visibility") diff --git a/synapse/types.py b/synapse/types.py index f89fb216a6..53be3583a0 100644 --- a/synapse/types.py +++ b/synapse/types.py @@ -51,7 +51,7 @@ from synapse.util.stringutils import parse_and_validate_server_name if TYPE_CHECKING: from synapse.appservice.api import ApplicationService - from synapse.storage.databases.main import DataStore + from synapse.storage.databases.main import DataStore, PurgeEventsStore # Define a state map type from type/state_key to T (usually an event ID or # event) @@ -485,7 +485,7 @@ class RoomStreamToken: ) @classmethod - async def parse(cls, store: "DataStore", string: str) -> "RoomStreamToken": + async def parse(cls, store: "PurgeEventsStore", string: str) -> "RoomStreamToken": try: if string[0] == "s": return cls(topological=None, stream=int(string[1:])) @@ -502,7 +502,7 @@ class RoomStreamToken: instance_id = int(key) pos = int(value) - instance_name = await store.get_name_from_instance_id(instance_id) + instance_name = await store.get_name_from_instance_id(instance_id) # type: ignore[attr-defined] instance_map[instance_name] = pos return cls(