Add type hints to `synapse/storage/databases/main` (#11984)
							parent
							
								
									99f6d79fe1
								
							
						
					
					
						commit
						7c82da27aa
					
				|  | @ -0,0 +1 @@ | |||
| Add missing type hints to storage classes. | ||||
							
								
								
									
										3
									
								
								mypy.ini
								
								
								
								
							
							
						
						
									
										3
									
								
								mypy.ini
								
								
								
								
							|  | @ -31,14 +31,11 @@ exclude = (?x) | |||
|    |synapse/storage/databases/main/group_server.py | ||||
|    |synapse/storage/databases/main/metrics.py | ||||
|    |synapse/storage/databases/main/monthly_active_users.py | ||||
|    |synapse/storage/databases/main/presence.py | ||||
|    |synapse/storage/databases/main/purge_events.py | ||||
|    |synapse/storage/databases/main/push_rule.py | ||||
|    |synapse/storage/databases/main/receipts.py | ||||
|    |synapse/storage/databases/main/roommember.py | ||||
|    |synapse/storage/databases/main/search.py | ||||
|    |synapse/storage/databases/main/state.py | ||||
|    |synapse/storage/databases/main/user_directory.py | ||||
|    |synapse/storage/schema/ | ||||
| 
 | ||||
|    |tests/api/test_auth.py | ||||
|  |  | |||
|  | @ -204,25 +204,27 @@ class BasePresenceHandler(abc.ABC): | |||
|         Returns: | ||||
|             dict: `user_id` -> `UserPresenceState` | ||||
|         """ | ||||
|         states = { | ||||
|             user_id: self.user_to_current_state.get(user_id, None) | ||||
|             for user_id in user_ids | ||||
|         } | ||||
|         states = {} | ||||
|         missing = [] | ||||
|         for user_id in user_ids: | ||||
|             state = self.user_to_current_state.get(user_id, None) | ||||
|             if state: | ||||
|                 states[user_id] = state | ||||
|             else: | ||||
|                 missing.append(user_id) | ||||
| 
 | ||||
|         missing = [user_id for user_id, state in states.items() if not state] | ||||
|         if missing: | ||||
|             # There are things not in our in memory cache. Lets pull them out of | ||||
|             # the database. | ||||
|             res = await self.store.get_presence_for_users(missing) | ||||
|             states.update(res) | ||||
| 
 | ||||
|             missing = [user_id for user_id, state in states.items() if not state] | ||||
|             if missing: | ||||
|                 new = { | ||||
|                     user_id: UserPresenceState.default(user_id) for user_id in missing | ||||
|                 } | ||||
|                 states.update(new) | ||||
|                 self.user_to_current_state.update(new) | ||||
|             for user_id in missing: | ||||
|                 # if user has no state in database, create the state | ||||
|                 if not res.get(user_id, None): | ||||
|                     new_state = UserPresenceState.default(user_id) | ||||
|                     states[user_id] = new_state | ||||
|                     self.user_to_current_state[user_id] = new_state | ||||
| 
 | ||||
|         return states | ||||
| 
 | ||||
|  |  | |||
|  | @ -12,15 +12,23 @@ | |||
| # See the License for the specific language governing permissions and | ||||
| # limitations under the License. | ||||
| 
 | ||||
| from typing import TYPE_CHECKING, Dict, Iterable, List, Tuple | ||||
| from typing import TYPE_CHECKING, Dict, Iterable, List, Tuple, cast | ||||
| 
 | ||||
| from synapse.api.presence import PresenceState, UserPresenceState | ||||
| from synapse.replication.tcp.streams import PresenceStream | ||||
| from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause | ||||
| from synapse.storage.database import DatabasePool, LoggingDatabaseConnection | ||||
| from synapse.storage.database import ( | ||||
|     DatabasePool, | ||||
|     LoggingDatabaseConnection, | ||||
|     LoggingTransaction, | ||||
| ) | ||||
| from synapse.storage.engines import PostgresEngine | ||||
| from synapse.storage.types import Connection | ||||
| from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator | ||||
| from synapse.storage.util.id_generators import ( | ||||
|     AbstractStreamIdGenerator, | ||||
|     MultiWriterIdGenerator, | ||||
|     StreamIdGenerator, | ||||
| ) | ||||
| from synapse.util.caches.descriptors import cached, cachedList | ||||
| from synapse.util.caches.stream_change_cache import StreamChangeCache | ||||
| from synapse.util.iterutils import batch_iter | ||||
|  | @ -35,7 +43,7 @@ class PresenceBackgroundUpdateStore(SQLBaseStore): | |||
|         database: DatabasePool, | ||||
|         db_conn: LoggingDatabaseConnection, | ||||
|         hs: "HomeServer", | ||||
|     ): | ||||
|     ) -> None: | ||||
|         super().__init__(database, db_conn, hs) | ||||
| 
 | ||||
|         # Used by `PresenceStore._get_active_presence()` | ||||
|  | @ -54,11 +62,14 @@ class PresenceStore(PresenceBackgroundUpdateStore): | |||
|         database: DatabasePool, | ||||
|         db_conn: LoggingDatabaseConnection, | ||||
|         hs: "HomeServer", | ||||
|     ): | ||||
|     ) -> None: | ||||
|         super().__init__(database, db_conn, hs) | ||||
| 
 | ||||
|         self._instance_name = hs.get_instance_name() | ||||
|         self._presence_id_gen: AbstractStreamIdGenerator | ||||
| 
 | ||||
|         self._can_persist_presence = ( | ||||
|             hs.get_instance_name() in hs.config.worker.writers.presence | ||||
|             self._instance_name in hs.config.worker.writers.presence | ||||
|         ) | ||||
| 
 | ||||
|         if isinstance(database.engine, PostgresEngine): | ||||
|  | @ -109,7 +120,9 @@ class PresenceStore(PresenceBackgroundUpdateStore): | |||
| 
 | ||||
|         return stream_orderings[-1], self._presence_id_gen.get_current_token() | ||||
| 
 | ||||
|     def _update_presence_txn(self, txn, stream_orderings, presence_states): | ||||
|     def _update_presence_txn( | ||||
|         self, txn: LoggingTransaction, stream_orderings, presence_states | ||||
|     ) -> None: | ||||
|         for stream_id, state in zip(stream_orderings, presence_states): | ||||
|             txn.call_after( | ||||
|                 self.presence_stream_cache.entity_has_changed, state.user_id, stream_id | ||||
|  | @ -183,19 +196,23 @@ class PresenceStore(PresenceBackgroundUpdateStore): | |||
|         if last_id == current_id: | ||||
|             return [], current_id, False | ||||
| 
 | ||||
|         def get_all_presence_updates_txn(txn): | ||||
|         def get_all_presence_updates_txn( | ||||
|             txn: LoggingTransaction, | ||||
|         ) -> Tuple[List[Tuple[int, list]], int, bool]: | ||||
|             sql = """ | ||||
|                 SELECT stream_id, user_id, state, last_active_ts, | ||||
|                     last_federation_update_ts, last_user_sync_ts, | ||||
|                     status_msg, | ||||
|                 currently_active | ||||
|                     status_msg, currently_active | ||||
|                 FROM presence_stream | ||||
|                 WHERE ? < stream_id AND stream_id <= ? | ||||
|                 ORDER BY stream_id ASC | ||||
|                 LIMIT ? | ||||
|             """ | ||||
|             txn.execute(sql, (last_id, current_id, limit)) | ||||
|             updates = [(row[0], row[1:]) for row in txn] | ||||
|             updates = cast( | ||||
|                 List[Tuple[int, list]], | ||||
|                 [(row[0], row[1:]) for row in txn], | ||||
|             ) | ||||
| 
 | ||||
|             upper_bound = current_id | ||||
|             limited = False | ||||
|  | @ -210,7 +227,7 @@ class PresenceStore(PresenceBackgroundUpdateStore): | |||
|         ) | ||||
| 
 | ||||
|     @cached() | ||||
|     def _get_presence_for_user(self, user_id): | ||||
|     def _get_presence_for_user(self, user_id: str) -> None: | ||||
|         raise NotImplementedError() | ||||
| 
 | ||||
|     @cachedList( | ||||
|  | @ -218,7 +235,9 @@ class PresenceStore(PresenceBackgroundUpdateStore): | |||
|         list_name="user_ids", | ||||
|         num_args=1, | ||||
|     ) | ||||
|     async def get_presence_for_users(self, user_ids): | ||||
|     async def get_presence_for_users( | ||||
|         self, user_ids: Iterable[str] | ||||
|     ) -> Dict[str, UserPresenceState]: | ||||
|         rows = await self.db_pool.simple_select_many_batch( | ||||
|             table="presence_stream", | ||||
|             column="user_id", | ||||
|  | @ -257,7 +276,9 @@ class PresenceStore(PresenceBackgroundUpdateStore): | |||
|             True if the user should have full presence sent to them, False otherwise. | ||||
|         """ | ||||
| 
 | ||||
|         def _should_user_receive_full_presence_with_token_txn(txn): | ||||
|         def _should_user_receive_full_presence_with_token_txn( | ||||
|             txn: LoggingTransaction, | ||||
|         ) -> bool: | ||||
|             sql = """ | ||||
|                 SELECT 1 FROM users_to_send_full_presence_to | ||||
|                 WHERE user_id = ? | ||||
|  | @ -271,7 +292,7 @@ class PresenceStore(PresenceBackgroundUpdateStore): | |||
|             _should_user_receive_full_presence_with_token_txn, | ||||
|         ) | ||||
| 
 | ||||
|     async def add_users_to_send_full_presence_to(self, user_ids: Iterable[str]): | ||||
|     async def add_users_to_send_full_presence_to(self, user_ids: Iterable[str]) -> None: | ||||
|         """Adds to the list of users who should receive a full snapshot of presence | ||||
|         upon their next sync. | ||||
| 
 | ||||
|  | @ -353,10 +374,10 @@ class PresenceStore(PresenceBackgroundUpdateStore): | |||
| 
 | ||||
|         return users_to_state | ||||
| 
 | ||||
|     def get_current_presence_token(self): | ||||
|     def get_current_presence_token(self) -> int: | ||||
|         return self._presence_id_gen.get_current_token() | ||||
| 
 | ||||
|     def _get_active_presence(self, db_conn: Connection): | ||||
|     def _get_active_presence(self, db_conn: Connection) -> List[UserPresenceState]: | ||||
|         """Fetch non-offline presence from the database so that we can register | ||||
|         the appropriate time outs. | ||||
|         """ | ||||
|  | @ -379,12 +400,12 @@ class PresenceStore(PresenceBackgroundUpdateStore): | |||
| 
 | ||||
|         return [UserPresenceState(**row) for row in rows] | ||||
| 
 | ||||
|     def take_presence_startup_info(self): | ||||
|     def take_presence_startup_info(self) -> List[UserPresenceState]: | ||||
|         active_on_startup = self._presence_on_startup | ||||
|         self._presence_on_startup = None | ||||
|         self._presence_on_startup = [] | ||||
|         return active_on_startup | ||||
| 
 | ||||
|     def process_replication_rows(self, stream_name, instance_name, token, rows): | ||||
|     def process_replication_rows(self, stream_name, instance_name, token, rows) -> None: | ||||
|         if stream_name == PresenceStream.NAME: | ||||
|             self._presence_id_gen.advance(instance_name, token) | ||||
|             for row in rows: | ||||
|  |  | |||
|  | @ -13,9 +13,10 @@ | |||
| # limitations under the License. | ||||
| 
 | ||||
| import logging | ||||
| from typing import Any, List, Set, Tuple | ||||
| from typing import Any, List, Set, Tuple, cast | ||||
| 
 | ||||
| from synapse.api.errors import SynapseError | ||||
| from synapse.storage.database import LoggingTransaction | ||||
| from synapse.storage.databases.main import CacheInvalidationWorkerStore | ||||
| from synapse.storage.databases.main.state import StateGroupWorkerStore | ||||
| from synapse.types import RoomStreamToken | ||||
|  | @ -55,7 +56,11 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore): | |||
|         ) | ||||
| 
 | ||||
|     def _purge_history_txn( | ||||
|         self, txn, room_id: str, token: RoomStreamToken, delete_local_events: bool | ||||
|         self, | ||||
|         txn: LoggingTransaction, | ||||
|         room_id: str, | ||||
|         token: RoomStreamToken, | ||||
|         delete_local_events: bool, | ||||
|     ) -> Set[int]: | ||||
|         # Tables that should be pruned: | ||||
|         #     event_auth | ||||
|  | @ -273,7 +278,7 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore): | |||
|         """, | ||||
|             (room_id,), | ||||
|         ) | ||||
|         (min_depth,) = txn.fetchone() | ||||
|         (min_depth,) = cast(Tuple[int], txn.fetchone()) | ||||
| 
 | ||||
|         logger.info("[purge] updating room_depth to %d", min_depth) | ||||
| 
 | ||||
|  | @ -318,7 +323,7 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore): | |||
|             "purge_room", self._purge_room_txn, room_id | ||||
|         ) | ||||
| 
 | ||||
|     def _purge_room_txn(self, txn, room_id: str) -> List[int]: | ||||
|     def _purge_room_txn(self, txn: LoggingTransaction, room_id: str) -> List[int]: | ||||
|         # First we fetch all the state groups that should be deleted, before | ||||
|         # we delete that information. | ||||
|         txn.execute( | ||||
|  |  | |||
|  | @ -58,7 +58,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore): | |||
|         database: DatabasePool, | ||||
|         db_conn: LoggingDatabaseConnection, | ||||
|         hs: "HomeServer", | ||||
|     ): | ||||
|     ) -> None: | ||||
|         super().__init__(database, db_conn, hs) | ||||
| 
 | ||||
|         self.server_name = hs.hostname | ||||
|  | @ -234,10 +234,10 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore): | |||
|         processed_event_count = 0 | ||||
| 
 | ||||
|         for room_id, event_count in rooms_to_work_on: | ||||
|             is_in_room = await self.is_host_joined(room_id, self.server_name) | ||||
|             is_in_room = await self.is_host_joined(room_id, self.server_name)  # type: ignore[attr-defined] | ||||
| 
 | ||||
|             if is_in_room: | ||||
|                 users_with_profile = await self.get_users_in_room_with_profiles(room_id) | ||||
|                 users_with_profile = await self.get_users_in_room_with_profiles(room_id)  # type: ignore[attr-defined] | ||||
|                 # Throw away users excluded from the directory. | ||||
|                 users_with_profile = { | ||||
|                     user_id: profile | ||||
|  | @ -368,7 +368,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore): | |||
| 
 | ||||
|         for user_id in users_to_work_on: | ||||
|             if await self.should_include_local_user_in_dir(user_id): | ||||
|                 profile = await self.get_profileinfo(get_localpart_from_id(user_id)) | ||||
|                 profile = await self.get_profileinfo(get_localpart_from_id(user_id))  # type: ignore[attr-defined] | ||||
|                 await self.update_profile_in_user_dir( | ||||
|                     user_id, profile.display_name, profile.avatar_url | ||||
|                 ) | ||||
|  | @ -397,7 +397,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore): | |||
|         # technically it could be DM-able. In the future, this could potentially | ||||
|         # be configurable per-appservice whether the appservice sender can be | ||||
|         # contacted. | ||||
|         if self.get_app_service_by_user_id(user) is not None: | ||||
|         if self.get_app_service_by_user_id(user) is not None:  # type: ignore[attr-defined] | ||||
|             return False | ||||
| 
 | ||||
|         # We're opting to exclude appservice users (anyone matching the user | ||||
|  | @ -405,17 +405,17 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore): | |||
|         # they could be DM-able. In the future, this could potentially | ||||
|         # be configurable per-appservice whether the appservice users can be | ||||
|         # contacted. | ||||
|         if self.get_if_app_services_interested_in_user(user): | ||||
|         if self.get_if_app_services_interested_in_user(user):  # type: ignore[attr-defined] | ||||
|             # TODO we might want to make this configurable for each app service | ||||
|             return False | ||||
| 
 | ||||
|         # Support users are for diagnostics and should not appear in the user directory. | ||||
|         if await self.is_support_user(user): | ||||
|         if await self.is_support_user(user):  # type: ignore[attr-defined] | ||||
|             return False | ||||
| 
 | ||||
|         # Deactivated users aren't contactable, so should not appear in the user directory. | ||||
|         try: | ||||
|             if await self.get_user_deactivated_status(user): | ||||
|             if await self.get_user_deactivated_status(user):  # type: ignore[attr-defined] | ||||
|                 return False | ||||
|         except StoreError: | ||||
|             # No such user in the users table. No need to do this when calling | ||||
|  | @ -433,20 +433,20 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore): | |||
|             (EventTypes.RoomHistoryVisibility, ""), | ||||
|         ) | ||||
| 
 | ||||
|         current_state_ids = await self.get_filtered_current_state_ids( | ||||
|         current_state_ids = await self.get_filtered_current_state_ids(  # type: ignore[attr-defined] | ||||
|             room_id, StateFilter.from_types(types_to_filter) | ||||
|         ) | ||||
| 
 | ||||
|         join_rules_id = current_state_ids.get((EventTypes.JoinRules, "")) | ||||
|         if join_rules_id: | ||||
|             join_rule_ev = await self.get_event(join_rules_id, allow_none=True) | ||||
|             join_rule_ev = await self.get_event(join_rules_id, allow_none=True)  # type: ignore[attr-defined] | ||||
|             if join_rule_ev: | ||||
|                 if join_rule_ev.content.get("join_rule") == JoinRules.PUBLIC: | ||||
|                     return True | ||||
| 
 | ||||
|         hist_vis_id = current_state_ids.get((EventTypes.RoomHistoryVisibility, "")) | ||||
|         if hist_vis_id: | ||||
|             hist_vis_ev = await self.get_event(hist_vis_id, allow_none=True) | ||||
|             hist_vis_ev = await self.get_event(hist_vis_id, allow_none=True)  # type: ignore[attr-defined] | ||||
|             if hist_vis_ev: | ||||
|                 if ( | ||||
|                     hist_vis_ev.content.get("history_visibility") | ||||
|  |  | |||
|  | @ -51,7 +51,7 @@ from synapse.util.stringutils import parse_and_validate_server_name | |||
| 
 | ||||
| if TYPE_CHECKING: | ||||
|     from synapse.appservice.api import ApplicationService | ||||
|     from synapse.storage.databases.main import DataStore | ||||
|     from synapse.storage.databases.main import DataStore, PurgeEventsStore | ||||
| 
 | ||||
| # Define a state map type from type/state_key to T (usually an event ID or | ||||
| # event) | ||||
|  | @ -485,7 +485,7 @@ class RoomStreamToken: | |||
|             ) | ||||
| 
 | ||||
|     @classmethod | ||||
|     async def parse(cls, store: "DataStore", string: str) -> "RoomStreamToken": | ||||
|     async def parse(cls, store: "PurgeEventsStore", string: str) -> "RoomStreamToken": | ||||
|         try: | ||||
|             if string[0] == "s": | ||||
|                 return cls(topological=None, stream=int(string[1:])) | ||||
|  | @ -502,7 +502,7 @@ class RoomStreamToken: | |||
|                     instance_id = int(key) | ||||
|                     pos = int(value) | ||||
| 
 | ||||
|                     instance_name = await store.get_name_from_instance_id(instance_id) | ||||
|                     instance_name = await store.get_name_from_instance_id(instance_id)  # type: ignore[attr-defined] | ||||
|                     instance_map[instance_name] = pos | ||||
| 
 | ||||
|                 return cls( | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue
	
	 Dirk Klimpel
						Dirk Klimpel