Add type hints to `synapse/storage/databases/main/room.py` (#11575)
							parent
							
								
									f901f8b70e
								
							
						
					
					
						commit
						c7fe32edb4
					
				|  | @ -0,0 +1 @@ | |||
| Add missing type hints to storage classes. | ||||
							
								
								
									
										4
									
								
								mypy.ini
								
								
								
								
							
							
						
						
									
										4
									
								
								mypy.ini
								
								
								
								
							|  | @ -37,7 +37,6 @@ exclude = (?x) | |||
|    |synapse/storage/databases/main/purge_events.py | ||||
|    |synapse/storage/databases/main/push_rule.py | ||||
|    |synapse/storage/databases/main/receipts.py | ||||
|    |synapse/storage/databases/main/room.py | ||||
|    |synapse/storage/databases/main/roommember.py | ||||
|    |synapse/storage/databases/main/search.py | ||||
|    |synapse/storage/databases/main/state.py | ||||
|  | @ -205,6 +204,9 @@ disallow_untyped_defs = True | |||
| [mypy-synapse.storage.databases.main.events_worker] | ||||
| disallow_untyped_defs = True | ||||
| 
 | ||||
| [mypy-synapse.storage.databases.main.room] | ||||
| disallow_untyped_defs = True | ||||
| 
 | ||||
| [mypy-synapse.storage.databases.main.room_batch] | ||||
| disallow_untyped_defs = True | ||||
| 
 | ||||
|  |  | |||
|  | @ -1020,7 +1020,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): | |||
|         # Add new room to the room directory if the old room was there | ||||
|         # Remove old room from the room directory | ||||
|         old_room = await self.store.get_room(old_room_id) | ||||
|         if old_room and old_room["is_public"]: | ||||
|         if old_room is not None and old_room["is_public"]: | ||||
|             await self.store.set_room_is_public(old_room_id, False) | ||||
|             await self.store.set_room_is_public(room_id, True) | ||||
| 
 | ||||
|  | @ -1031,7 +1031,9 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): | |||
|         local_group_ids = await self.store.get_local_groups_for_room(old_room_id) | ||||
|         for group_id in local_group_ids: | ||||
|             # Add new the new room to those groups | ||||
|             await self.store.add_room_to_group(group_id, room_id, old_room["is_public"]) | ||||
|             await self.store.add_room_to_group( | ||||
|                 group_id, room_id, old_room is not None and old_room["is_public"] | ||||
|             ) | ||||
| 
 | ||||
|             # Remove the old room from those groups | ||||
|             await self.store.remove_room_from_group(group_id, old_room_id) | ||||
|  |  | |||
|  | @ -149,7 +149,6 @@ class DataStore( | |||
|             ], | ||||
|         ) | ||||
| 
 | ||||
|         self._event_reports_id_gen = IdGenerator(db_conn, "event_reports", "id") | ||||
|         self._push_rule_id_gen = IdGenerator(db_conn, "push_rules", "id") | ||||
|         self._push_rules_enable_id_gen = IdGenerator(db_conn, "push_rules_enable", "id") | ||||
|         self._group_updates_id_gen = StreamIdGenerator( | ||||
|  |  | |||
|  | @ -17,7 +17,7 @@ import collections | |||
| import logging | ||||
| from abc import abstractmethod | ||||
| from enum import Enum | ||||
| from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple | ||||
| from typing import TYPE_CHECKING, Any, Awaitable, Dict, List, Optional, Tuple, cast | ||||
| 
 | ||||
| from synapse.api.constants import EventContentFields, EventTypes, JoinRules | ||||
| from synapse.api.errors import StoreError | ||||
|  | @ -29,8 +29,9 @@ from synapse.storage.database import ( | |||
|     LoggingDatabaseConnection, | ||||
|     LoggingTransaction, | ||||
| ) | ||||
| from synapse.storage.databases.main.search import SearchStore | ||||
| from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore | ||||
| from synapse.storage.types import Cursor | ||||
| from synapse.storage.util.id_generators import IdGenerator | ||||
| from synapse.types import JsonDict, ThirdPartyInstanceID | ||||
| from synapse.util import json_encoder | ||||
| from synapse.util.caches.descriptors import cached | ||||
|  | @ -75,7 +76,7 @@ class RoomSortOrder(Enum): | |||
|     STATE_EVENTS = "state_events" | ||||
| 
 | ||||
| 
 | ||||
| class RoomWorkerStore(SQLBaseStore): | ||||
| class RoomWorkerStore(CacheInvalidationWorkerStore): | ||||
|     def __init__( | ||||
|         self, | ||||
|         database: DatabasePool, | ||||
|  | @ -92,7 +93,7 @@ class RoomWorkerStore(SQLBaseStore): | |||
|         room_creator_user_id: str, | ||||
|         is_public: bool, | ||||
|         room_version: RoomVersion, | ||||
|     ): | ||||
|     ) -> None: | ||||
|         """Stores a room. | ||||
| 
 | ||||
|         Args: | ||||
|  | @ -120,7 +121,7 @@ class RoomWorkerStore(SQLBaseStore): | |||
|             logger.error("store_room with room_id=%s failed: %s", room_id, e) | ||||
|             raise StoreError(500, "Problem creating room.") | ||||
| 
 | ||||
|     async def get_room(self, room_id: str) -> dict: | ||||
|     async def get_room(self, room_id: str) -> Optional[Dict[str, Any]]: | ||||
|         """Retrieve a room. | ||||
| 
 | ||||
|         Args: | ||||
|  | @ -145,7 +146,9 @@ class RoomWorkerStore(SQLBaseStore): | |||
|             A dict containing the room information, or None if the room is unknown. | ||||
|         """ | ||||
| 
 | ||||
|         def get_room_with_stats_txn(txn, room_id): | ||||
|         def get_room_with_stats_txn( | ||||
|             txn: LoggingTransaction, room_id: str | ||||
|         ) -> Optional[Dict[str, Any]]: | ||||
|             sql = """ | ||||
|                 SELECT room_id, state.name, state.canonical_alias, curr.joined_members, | ||||
|                   curr.local_users_in_room AS joined_local_members, rooms.room_version AS version, | ||||
|  | @ -194,7 +197,7 @@ class RoomWorkerStore(SQLBaseStore): | |||
|             ignore_non_federatable: If true filters out non-federatable rooms | ||||
|         """ | ||||
| 
 | ||||
|         def _count_public_rooms_txn(txn): | ||||
|         def _count_public_rooms_txn(txn: LoggingTransaction) -> int: | ||||
|             query_args = [] | ||||
| 
 | ||||
|             if network_tuple: | ||||
|  | @ -235,7 +238,7 @@ class RoomWorkerStore(SQLBaseStore): | |||
|             } | ||||
| 
 | ||||
|             txn.execute(sql, query_args) | ||||
|             return txn.fetchone()[0] | ||||
|             return cast(Tuple[int], txn.fetchone())[0] | ||||
| 
 | ||||
|         return await self.db_pool.runInteraction( | ||||
|             "count_public_rooms", _count_public_rooms_txn | ||||
|  | @ -244,11 +247,11 @@ class RoomWorkerStore(SQLBaseStore): | |||
|     async def get_room_count(self) -> int: | ||||
|         """Retrieve the total number of rooms.""" | ||||
| 
 | ||||
|         def f(txn): | ||||
|         def f(txn: LoggingTransaction) -> int: | ||||
|             sql = "SELECT count(*)  FROM rooms" | ||||
|             txn.execute(sql) | ||||
|             row = txn.fetchone() | ||||
|             return row[0] or 0 | ||||
|             row = cast(Tuple[int], txn.fetchone()) | ||||
|             return row[0] | ||||
| 
 | ||||
|         return await self.db_pool.runInteraction("get_rooms", f) | ||||
| 
 | ||||
|  | @ -260,7 +263,7 @@ class RoomWorkerStore(SQLBaseStore): | |||
|         bounds: Optional[Tuple[int, str]], | ||||
|         forwards: bool, | ||||
|         ignore_non_federatable: bool = False, | ||||
|     ): | ||||
|     ) -> List[Dict[str, Any]]: | ||||
|         """Gets the largest public rooms (where largest is in terms of joined | ||||
|         members, as tracked in the statistics table). | ||||
| 
 | ||||
|  | @ -381,7 +384,9 @@ class RoomWorkerStore(SQLBaseStore): | |||
|                 LIMIT ? | ||||
|             """ | ||||
| 
 | ||||
|         def _get_largest_public_rooms_txn(txn): | ||||
|         def _get_largest_public_rooms_txn( | ||||
|             txn: LoggingTransaction, | ||||
|         ) -> List[Dict[str, Any]]: | ||||
|             txn.execute(sql, query_args) | ||||
| 
 | ||||
|             results = self.db_pool.cursor_to_dict(txn) | ||||
|  | @ -444,7 +449,7 @@ class RoomWorkerStore(SQLBaseStore): | |||
|         """ | ||||
|         # Filter room names by a string | ||||
|         where_statement = "" | ||||
|         search_pattern = [] | ||||
|         search_pattern: List[object] = [] | ||||
|         if search_term: | ||||
|             where_statement = """ | ||||
|                 WHERE LOWER(state.name) LIKE ? | ||||
|  | @ -552,7 +557,9 @@ class RoomWorkerStore(SQLBaseStore): | |||
|             where_statement, | ||||
|         ) | ||||
| 
 | ||||
|         def _get_rooms_paginate_txn(txn): | ||||
|         def _get_rooms_paginate_txn( | ||||
|             txn: LoggingTransaction, | ||||
|         ) -> Tuple[List[Dict[str, Any]], int]: | ||||
|             # Add the search term into the WHERE clause | ||||
|             # and execute the data query | ||||
|             txn.execute(info_sql, search_pattern + [limit, start]) | ||||
|  | @ -584,7 +591,7 @@ class RoomWorkerStore(SQLBaseStore): | |||
|             # Add the search term into the WHERE clause if present | ||||
|             txn.execute(count_sql, search_pattern) | ||||
| 
 | ||||
|             room_count = txn.fetchone() | ||||
|             room_count = cast(Tuple[int], txn.fetchone()) | ||||
|             return rooms, room_count[0] | ||||
| 
 | ||||
|         return await self.db_pool.runInteraction( | ||||
|  | @ -629,7 +636,7 @@ class RoomWorkerStore(SQLBaseStore): | |||
|             burst_count: How many actions that can be performed before being limited. | ||||
|         """ | ||||
| 
 | ||||
|         def set_ratelimit_txn(txn): | ||||
|         def set_ratelimit_txn(txn: LoggingTransaction) -> None: | ||||
|             self.db_pool.simple_upsert_txn( | ||||
|                 txn, | ||||
|                 table="ratelimit_override", | ||||
|  | @ -652,7 +659,7 @@ class RoomWorkerStore(SQLBaseStore): | |||
|             user_id: user ID of the user | ||||
|         """ | ||||
| 
 | ||||
|         def delete_ratelimit_txn(txn): | ||||
|         def delete_ratelimit_txn(txn: LoggingTransaction) -> None: | ||||
|             row = self.db_pool.simple_select_one_txn( | ||||
|                 txn, | ||||
|                 table="ratelimit_override", | ||||
|  | @ -676,7 +683,7 @@ class RoomWorkerStore(SQLBaseStore): | |||
|         await self.db_pool.runInteraction("delete_ratelimit", delete_ratelimit_txn) | ||||
| 
 | ||||
|     @cached() | ||||
|     async def get_retention_policy_for_room(self, room_id): | ||||
|     async def get_retention_policy_for_room(self, room_id: str) -> Dict[str, int]: | ||||
|         """Get the retention policy for a given room. | ||||
| 
 | ||||
|         If no retention policy has been found for this room, returns a policy defined | ||||
|  | @ -685,13 +692,15 @@ class RoomWorkerStore(SQLBaseStore): | |||
|         configuration). | ||||
| 
 | ||||
|         Args: | ||||
|             room_id (str): The ID of the room to get the retention policy of. | ||||
|             room_id: The ID of the room to get the retention policy of. | ||||
| 
 | ||||
|         Returns: | ||||
|             dict[int, int]: "min_lifetime" and "max_lifetime" for this room. | ||||
|             A dict containing "min_lifetime" and "max_lifetime" for this room. | ||||
|         """ | ||||
| 
 | ||||
|         def get_retention_policy_for_room_txn(txn): | ||||
|         def get_retention_policy_for_room_txn( | ||||
|             txn: LoggingTransaction, | ||||
|         ) -> List[Dict[str, Optional[int]]]: | ||||
|             txn.execute( | ||||
|                 """ | ||||
|                 SELECT min_lifetime, max_lifetime FROM room_retention | ||||
|  | @ -716,19 +725,23 @@ class RoomWorkerStore(SQLBaseStore): | |||
|                 "max_lifetime": self.config.retention.retention_default_max_lifetime, | ||||
|             } | ||||
| 
 | ||||
|         row = ret[0] | ||||
|         min_lifetime = ret[0]["min_lifetime"] | ||||
|         max_lifetime = ret[0]["max_lifetime"] | ||||
| 
 | ||||
|         # If one of the room's policy's attributes isn't defined, use the matching | ||||
|         # attribute from the default policy. | ||||
|         # The default values will be None if no default policy has been defined, or if one | ||||
|         # of the attributes is missing from the default policy. | ||||
|         if row["min_lifetime"] is None: | ||||
|             row["min_lifetime"] = self.config.retention.retention_default_min_lifetime | ||||
|         if min_lifetime is None: | ||||
|             min_lifetime = self.config.retention.retention_default_min_lifetime | ||||
| 
 | ||||
|         if row["max_lifetime"] is None: | ||||
|             row["max_lifetime"] = self.config.retention.retention_default_max_lifetime | ||||
|         if max_lifetime is None: | ||||
|             max_lifetime = self.config.retention.retention_default_max_lifetime | ||||
| 
 | ||||
|         return row | ||||
|         return { | ||||
|             "min_lifetime": min_lifetime, | ||||
|             "max_lifetime": max_lifetime, | ||||
|         } | ||||
| 
 | ||||
|     async def get_media_mxcs_in_room(self, room_id: str) -> Tuple[List[str], List[str]]: | ||||
|         """Retrieves all the local and remote media MXC URIs in a given room | ||||
|  | @ -740,7 +753,9 @@ class RoomWorkerStore(SQLBaseStore): | |||
|             The local and remote media as a lists of the media IDs. | ||||
|         """ | ||||
| 
 | ||||
|         def _get_media_mxcs_in_room_txn(txn): | ||||
|         def _get_media_mxcs_in_room_txn( | ||||
|             txn: LoggingTransaction, | ||||
|         ) -> Tuple[List[str], List[str]]: | ||||
|             local_mxcs, remote_mxcs = self._get_media_mxcs_in_room_txn(txn, room_id) | ||||
|             local_media_mxcs = [] | ||||
|             remote_media_mxcs = [] | ||||
|  | @ -766,7 +781,7 @@ class RoomWorkerStore(SQLBaseStore): | |||
| 
 | ||||
|         logger.info("Quarantining media in room: %s", room_id) | ||||
| 
 | ||||
|         def _quarantine_media_in_room_txn(txn): | ||||
|         def _quarantine_media_in_room_txn(txn: LoggingTransaction) -> int: | ||||
|             local_mxcs, remote_mxcs = self._get_media_mxcs_in_room_txn(txn, room_id) | ||||
|             return self._quarantine_media_txn( | ||||
|                 txn, local_mxcs, remote_mxcs, quarantined_by | ||||
|  | @ -776,13 +791,11 @@ class RoomWorkerStore(SQLBaseStore): | |||
|             "quarantine_media_in_room", _quarantine_media_in_room_txn | ||||
|         ) | ||||
| 
 | ||||
|     def _get_media_mxcs_in_room_txn(self, txn, room_id): | ||||
|     def _get_media_mxcs_in_room_txn( | ||||
|         self, txn: LoggingTransaction, room_id: str | ||||
|     ) -> Tuple[List[str], List[Tuple[str, str]]]: | ||||
|         """Retrieves all the local and remote media MXC URIs in a given room | ||||
| 
 | ||||
|         Args: | ||||
|             txn (cursor) | ||||
|             room_id (str) | ||||
| 
 | ||||
|         Returns: | ||||
|             The local and remote media as a lists of tuples where the key is | ||||
|             the hostname and the value is the media ID. | ||||
|  | @ -850,7 +863,7 @@ class RoomWorkerStore(SQLBaseStore): | |||
|         logger.info("Quarantining media: %s/%s", server_name, media_id) | ||||
|         is_local = server_name == self.config.server.server_name | ||||
| 
 | ||||
|         def _quarantine_media_by_id_txn(txn): | ||||
|         def _quarantine_media_by_id_txn(txn: LoggingTransaction) -> int: | ||||
|             local_mxcs = [media_id] if is_local else [] | ||||
|             remote_mxcs = [(server_name, media_id)] if not is_local else [] | ||||
| 
 | ||||
|  | @ -872,7 +885,7 @@ class RoomWorkerStore(SQLBaseStore): | |||
|             quarantined_by: The ID of the user who made the quarantine request | ||||
|         """ | ||||
| 
 | ||||
|         def _quarantine_media_by_user_txn(txn): | ||||
|         def _quarantine_media_by_user_txn(txn: LoggingTransaction) -> int: | ||||
|             local_media_ids = self._get_media_ids_by_user_txn(txn, user_id) | ||||
|             return self._quarantine_media_txn(txn, local_media_ids, [], quarantined_by) | ||||
| 
 | ||||
|  | @ -880,7 +893,9 @@ class RoomWorkerStore(SQLBaseStore): | |||
|             "quarantine_media_by_user", _quarantine_media_by_user_txn | ||||
|         ) | ||||
| 
 | ||||
|     def _get_media_ids_by_user_txn(self, txn, user_id: str, filter_quarantined=True): | ||||
|     def _get_media_ids_by_user_txn( | ||||
|         self, txn: LoggingTransaction, user_id: str, filter_quarantined: bool = True | ||||
|     ) -> List[str]: | ||||
|         """Retrieves local media IDs by a given user | ||||
| 
 | ||||
|         Args: | ||||
|  | @ -909,7 +924,7 @@ class RoomWorkerStore(SQLBaseStore): | |||
| 
 | ||||
|     def _quarantine_media_txn( | ||||
|         self, | ||||
|         txn, | ||||
|         txn: LoggingTransaction, | ||||
|         local_mxcs: List[str], | ||||
|         remote_mxcs: List[Tuple[str, str]], | ||||
|         quarantined_by: Optional[str], | ||||
|  | @ -937,12 +952,15 @@ class RoomWorkerStore(SQLBaseStore): | |||
|         # set quarantine | ||||
|         if quarantined_by is not None: | ||||
|             sql += "AND safe_from_quarantine = ?" | ||||
|             rows = [(quarantined_by, media_id, False) for media_id in local_mxcs] | ||||
|             txn.executemany( | ||||
|                 sql, [(quarantined_by, media_id, False) for media_id in local_mxcs] | ||||
|             ) | ||||
|         # remove from quarantine | ||||
|         else: | ||||
|             rows = [(quarantined_by, media_id) for media_id in local_mxcs] | ||||
|             txn.executemany( | ||||
|                 sql, [(quarantined_by, media_id) for media_id in local_mxcs] | ||||
|             ) | ||||
| 
 | ||||
|         txn.executemany(sql, rows) | ||||
|         # Note that a rowcount of -1 can be used to indicate no rows were affected. | ||||
|         total_media_quarantined = txn.rowcount if txn.rowcount > 0 else 0 | ||||
| 
 | ||||
|  | @ -960,7 +978,7 @@ class RoomWorkerStore(SQLBaseStore): | |||
| 
 | ||||
|     async def get_rooms_for_retention_period_in_range( | ||||
|         self, min_ms: Optional[int], max_ms: Optional[int], include_null: bool = False | ||||
|     ) -> Dict[str, dict]: | ||||
|     ) -> Dict[str, Dict[str, Optional[int]]]: | ||||
|         """Retrieves all of the rooms within the given retention range. | ||||
| 
 | ||||
|         Optionally includes the rooms which don't have a retention policy. | ||||
|  | @ -980,7 +998,9 @@ class RoomWorkerStore(SQLBaseStore): | |||
|             "min_lifetime" (int|None), and "max_lifetime" (int|None). | ||||
|         """ | ||||
| 
 | ||||
|         def get_rooms_for_retention_period_in_range_txn(txn): | ||||
|         def get_rooms_for_retention_period_in_range_txn( | ||||
|             txn: LoggingTransaction, | ||||
|         ) -> Dict[str, Dict[str, Optional[int]]]: | ||||
|             range_conditions = [] | ||||
|             args = [] | ||||
| 
 | ||||
|  | @ -1067,8 +1087,6 @@ class RoomBackgroundUpdateStore(SQLBaseStore): | |||
|     ): | ||||
|         super().__init__(database, db_conn, hs) | ||||
| 
 | ||||
|         self.config = hs.config | ||||
| 
 | ||||
|         self.db_pool.updates.register_background_update_handler( | ||||
|             "insert_room_retention", | ||||
|             self._background_insert_retention, | ||||
|  | @ -1099,7 +1117,9 @@ class RoomBackgroundUpdateStore(SQLBaseStore): | |||
|             self._background_populate_rooms_creator_column, | ||||
|         ) | ||||
| 
 | ||||
|     async def _background_insert_retention(self, progress, batch_size): | ||||
|     async def _background_insert_retention( | ||||
|         self, progress: JsonDict, batch_size: int | ||||
|     ) -> int: | ||||
|         """Retrieves a list of all rooms within a range and inserts an entry for each of | ||||
|         them into the room_retention table. | ||||
|         NULLs the property's columns if missing from the retention event in the room's | ||||
|  | @ -1109,7 +1129,7 @@ class RoomBackgroundUpdateStore(SQLBaseStore): | |||
| 
 | ||||
|         last_room = progress.get("room_id", "") | ||||
| 
 | ||||
|         def _background_insert_retention_txn(txn): | ||||
|         def _background_insert_retention_txn(txn: LoggingTransaction) -> bool: | ||||
|             txn.execute( | ||||
|                 """ | ||||
|                 SELECT state.room_id, state.event_id, events.json | ||||
|  | @ -1168,15 +1188,17 @@ class RoomBackgroundUpdateStore(SQLBaseStore): | |||
|         return batch_size | ||||
| 
 | ||||
|     async def _background_add_rooms_room_version_column( | ||||
|         self, progress: dict, batch_size: int | ||||
|     ): | ||||
|         self, progress: JsonDict, batch_size: int | ||||
|     ) -> int: | ||||
|         """Background update to go and add room version information to `rooms` | ||||
|         table from `current_state_events` table. | ||||
|         """ | ||||
| 
 | ||||
|         last_room_id = progress.get("room_id", "") | ||||
| 
 | ||||
|         def _background_add_rooms_room_version_column_txn(txn: LoggingTransaction): | ||||
|         def _background_add_rooms_room_version_column_txn( | ||||
|             txn: LoggingTransaction, | ||||
|         ) -> bool: | ||||
|             sql = """ | ||||
|                 SELECT room_id, json FROM current_state_events | ||||
|                 INNER JOIN event_json USING (room_id, event_id) | ||||
|  | @ -1237,7 +1259,7 @@ class RoomBackgroundUpdateStore(SQLBaseStore): | |||
|         return batch_size | ||||
| 
 | ||||
|     async def _remove_tombstoned_rooms_from_directory( | ||||
|         self, progress, batch_size | ||||
|         self, progress: JsonDict, batch_size: int | ||||
|     ) -> int: | ||||
|         """Removes any rooms with tombstone events from the room directory | ||||
| 
 | ||||
|  | @ -1247,7 +1269,7 @@ class RoomBackgroundUpdateStore(SQLBaseStore): | |||
| 
 | ||||
|         last_room = progress.get("room_id", "") | ||||
| 
 | ||||
|         def _get_rooms(txn): | ||||
|         def _get_rooms(txn: LoggingTransaction) -> List[str]: | ||||
|             txn.execute( | ||||
|                 """ | ||||
|                 SELECT room_id | ||||
|  | @ -1285,7 +1307,7 @@ class RoomBackgroundUpdateStore(SQLBaseStore): | |||
|         return len(rooms) | ||||
| 
 | ||||
|     @abstractmethod | ||||
|     def set_room_is_public(self, room_id, is_public): | ||||
|     def set_room_is_public(self, room_id: str, is_public: bool) -> Awaitable[None]: | ||||
|         # this will need to be implemented if a background update is performed with | ||||
|         # existing (tombstoned, public) rooms in the database. | ||||
|         # | ||||
|  | @ -1332,7 +1354,7 @@ class RoomBackgroundUpdateStore(SQLBaseStore): | |||
|         32-bit integer field. | ||||
|         """ | ||||
| 
 | ||||
|         def process(txn: Cursor) -> int: | ||||
|         def process(txn: LoggingTransaction) -> int: | ||||
|             last_room = progress.get("last_room", "") | ||||
|             txn.execute( | ||||
|                 """ | ||||
|  | @ -1389,15 +1411,17 @@ class RoomBackgroundUpdateStore(SQLBaseStore): | |||
|         return 0 | ||||
| 
 | ||||
|     async def _background_populate_rooms_creator_column( | ||||
|         self, progress: dict, batch_size: int | ||||
|     ): | ||||
|         self, progress: JsonDict, batch_size: int | ||||
|     ) -> int: | ||||
|         """Background update to go and add creator information to `rooms` | ||||
|         table from `current_state_events` table. | ||||
|         """ | ||||
| 
 | ||||
|         last_room_id = progress.get("room_id", "") | ||||
| 
 | ||||
|         def _background_populate_rooms_creator_column_txn(txn: LoggingTransaction): | ||||
|         def _background_populate_rooms_creator_column_txn( | ||||
|             txn: LoggingTransaction, | ||||
|         ) -> bool: | ||||
|             sql = """ | ||||
|                 SELECT room_id, json FROM event_json | ||||
|                 INNER JOIN rooms AS room USING (room_id) | ||||
|  | @ -1448,7 +1472,7 @@ class RoomBackgroundUpdateStore(SQLBaseStore): | |||
|         return batch_size | ||||
| 
 | ||||
| 
 | ||||
| class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): | ||||
| class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore): | ||||
|     def __init__( | ||||
|         self, | ||||
|         database: DatabasePool, | ||||
|  | @ -1457,11 +1481,11 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): | |||
|     ): | ||||
|         super().__init__(database, db_conn, hs) | ||||
| 
 | ||||
|         self.config = hs.config | ||||
|         self._event_reports_id_gen = IdGenerator(db_conn, "event_reports", "id") | ||||
| 
 | ||||
|     async def upsert_room_on_join( | ||||
|         self, room_id: str, room_version: RoomVersion, auth_events: List[EventBase] | ||||
|     ): | ||||
|     ) -> None: | ||||
|         """Ensure that the room is stored in the table | ||||
| 
 | ||||
|         Called when we join a room over federation, and overwrites any room version | ||||
|  | @ -1507,7 +1531,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): | |||
| 
 | ||||
|     async def maybe_store_room_on_outlier_membership( | ||||
|         self, room_id: str, room_version: RoomVersion | ||||
|     ): | ||||
|     ) -> None: | ||||
|         """ | ||||
|         When we receive an invite or any other event over federation that may relate to a room | ||||
|         we are not in, store the version of the room if we don't already know the room version. | ||||
|  | @ -1547,8 +1571,8 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): | |||
|         self.hs.get_notifier().on_new_replication_data() | ||||
| 
 | ||||
|     async def set_room_is_public_appservice( | ||||
|         self, room_id, appservice_id, network_id, is_public | ||||
|     ): | ||||
|         self, room_id: str, appservice_id: str, network_id: str, is_public: bool | ||||
|     ) -> None: | ||||
|         """Edit the appservice/network specific public room list. | ||||
| 
 | ||||
|         Each appservice can have a number of published room lists associated | ||||
|  | @ -1557,11 +1581,10 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): | |||
|         network. | ||||
| 
 | ||||
|         Args: | ||||
|             room_id (str) | ||||
|             appservice_id (str) | ||||
|             network_id (str) | ||||
|             is_public (bool): Whether to publish or unpublish the room from the | ||||
|                 list. | ||||
|             room_id | ||||
|             appservice_id | ||||
|             network_id | ||||
|             is_public: Whether to publish or unpublish the room from the list. | ||||
|         """ | ||||
| 
 | ||||
|         if is_public: | ||||
|  | @ -1626,7 +1649,9 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): | |||
|             event_report: json list of information from event report | ||||
|         """ | ||||
| 
 | ||||
|         def _get_event_report_txn(txn, report_id): | ||||
|         def _get_event_report_txn( | ||||
|             txn: LoggingTransaction, report_id: int | ||||
|         ) -> Optional[Dict[str, Any]]: | ||||
| 
 | ||||
|             sql = """ | ||||
|                 SELECT | ||||
|  | @ -1698,9 +1723,11 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): | |||
|             count: total number of event reports matching the filter criteria | ||||
|         """ | ||||
| 
 | ||||
|         def _get_event_reports_paginate_txn(txn): | ||||
|         def _get_event_reports_paginate_txn( | ||||
|             txn: LoggingTransaction, | ||||
|         ) -> Tuple[List[Dict[str, Any]], int]: | ||||
|             filters = [] | ||||
|             args = [] | ||||
|             args: List[object] = [] | ||||
| 
 | ||||
|             if user_id: | ||||
|                 filters.append("er.user_id LIKE ?") | ||||
|  | @ -1724,7 +1751,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): | |||
|                 where_clause | ||||
|             ) | ||||
|             txn.execute(sql, args) | ||||
|             count = txn.fetchone()[0] | ||||
|             count = cast(Tuple[int], txn.fetchone())[0] | ||||
| 
 | ||||
|             sql = """ | ||||
|                 SELECT | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue
	
	 Sean Quah
						Sean Quah