Add type hints to `synapse/storage/databases/main/room.py` (#11575)

Sean Quah 2021-12-15 18:00:48 +00:00 committed by GitHub
parent f901f8b70e
commit c7fe32edb4
5 changed files with 108 additions and 77 deletions

changelog.d/11575.misc (new file)

@@ -0,0 +1 @@
+Add missing type hints to storage classes.

mypy.ini

@@ -37,7 +37,6 @@ exclude = (?x)
   |synapse/storage/databases/main/purge_events.py
   |synapse/storage/databases/main/push_rule.py
   |synapse/storage/databases/main/receipts.py
-  |synapse/storage/databases/main/room.py
   |synapse/storage/databases/main/roommember.py
   |synapse/storage/databases/main/search.py
   |synapse/storage/databases/main/state.py
@@ -205,6 +204,9 @@ disallow_untyped_defs = True
 [mypy-synapse.storage.databases.main.events_worker]
 disallow_untyped_defs = True

+[mypy-synapse.storage.databases.main.room]
+disallow_untyped_defs = True
+
 [mypy-synapse.storage.databases.main.room_batch]
 disallow_untyped_defs = True
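
With `synapse/storage/databases/main/room.py` removed from the mypy exclude list and placed under `disallow_untyped_defs`, every function in that module must now be fully annotated. A minimal standalone sketch (not Synapse code) of what that flag rejects and accepts, assuming the module is covered by such a config section:

    from typing import Any, Dict, Optional


    def get_room(room_id):  # mypy: "Function is missing a type annotation"
        return None


    def get_room_typed(room_id: str) -> Optional[Dict[str, Any]]:  # accepted
        return None
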

synapse/handlers/room_member.py

@@ -1020,7 +1020,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
         # Add new room to the room directory if the old room was there
         # Remove old room from the room directory
         old_room = await self.store.get_room(old_room_id)
-        if old_room and old_room["is_public"]:
+        if old_room is not None and old_room["is_public"]:
             await self.store.set_room_is_public(old_room_id, False)
             await self.store.set_room_is_public(room_id, True)
@@ -1031,7 +1031,9 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
         local_group_ids = await self.store.get_local_groups_for_room(old_room_id)
         for group_id in local_group_ids:
             # Add new the new room to those groups
-            await self.store.add_room_to_group(group_id, room_id, old_room["is_public"])
+            await self.store.add_room_to_group(
+                group_id, room_id, old_room is not None and old_room["is_public"]
+            )

             # Remove the old room from those groups
             await self.store.remove_room_from_group(group_id, old_room_id)
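
These handler changes follow from `get_room` now being annotated as returning `Optional[Dict[str, Any]]` (see the storage diff below): the result has to be narrowed before it is indexed, otherwise mypy flags the optional value as not indexable. A small self-contained sketch of the narrowing pattern, with illustrative names rather than the Synapse API:

    from typing import Any, Dict, Optional


    async def get_room(room_id: str) -> Optional[Dict[str, Any]]:
        return None  # None means the room is unknown


    async def room_is_public(room_id: str) -> bool:
        room = await get_room(room_id)
        # Indexing `room` without excluding None would be rejected by mypy.
        return room is not None and bool(room["is_public"])
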

synapse/storage/databases/main/__init__.py

@@ -149,7 +149,6 @@ class DataStore(
             ],
         )

-        self._event_reports_id_gen = IdGenerator(db_conn, "event_reports", "id")
         self._push_rule_id_gen = IdGenerator(db_conn, "push_rules", "id")
         self._push_rules_enable_id_gen = IdGenerator(db_conn, "push_rules_enable", "id")
         self._group_updates_id_gen = StreamIdGenerator(

synapse/storage/databases/main/room.py

@@ -17,7 +17,7 @@ import collections
 import logging
 from abc import abstractmethod
 from enum import Enum
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
+from typing import TYPE_CHECKING, Any, Awaitable, Dict, List, Optional, Tuple, cast

 from synapse.api.constants import EventContentFields, EventTypes, JoinRules
 from synapse.api.errors import StoreError
@@ -29,8 +29,9 @@ from synapse.storage.database import (
     LoggingDatabaseConnection,
     LoggingTransaction,
 )
-from synapse.storage.databases.main.search import SearchStore
+from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
 from synapse.storage.types import Cursor
+from synapse.storage.util.id_generators import IdGenerator
 from synapse.types import JsonDict, ThirdPartyInstanceID
 from synapse.util import json_encoder
 from synapse.util.caches.descriptors import cached
@@ -75,7 +76,7 @@ class RoomSortOrder(Enum):
     STATE_EVENTS = "state_events"


-class RoomWorkerStore(SQLBaseStore):
+class RoomWorkerStore(CacheInvalidationWorkerStore):
     def __init__(
         self,
         database: DatabasePool,
@@ -92,7 +93,7 @@ class RoomWorkerStore(SQLBaseStore):
         room_creator_user_id: str,
         is_public: bool,
         room_version: RoomVersion,
-    ):
+    ) -> None:
         """Stores a room.

         Args:
@@ -120,7 +121,7 @@ class RoomWorkerStore(SQLBaseStore):
             logger.error("store_room with room_id=%s failed: %s", room_id, e)
             raise StoreError(500, "Problem creating room.")

-    async def get_room(self, room_id: str) -> dict:
+    async def get_room(self, room_id: str) -> Optional[Dict[str, Any]]:
         """Retrieve a room.

         Args:
@@ -145,7 +146,9 @@ class RoomWorkerStore(SQLBaseStore):
             A dict containing the room information, or None if the room is unknown.
         """

-        def get_room_with_stats_txn(txn, room_id):
+        def get_room_with_stats_txn(
+            txn: LoggingTransaction, room_id: str
+        ) -> Optional[Dict[str, Any]]:
             sql = """
                 SELECT room_id, state.name, state.canonical_alias, curr.joined_members,
                   curr.local_users_in_room AS joined_local_members, rooms.room_version AS version,
@@ -194,7 +197,7 @@ class RoomWorkerStore(SQLBaseStore):
             ignore_non_federatable: If true filters out non-federatable rooms
         """

-        def _count_public_rooms_txn(txn):
+        def _count_public_rooms_txn(txn: LoggingTransaction) -> int:
             query_args = []

             if network_tuple:
@@ -235,7 +238,7 @@ class RoomWorkerStore(SQLBaseStore):
             }

             txn.execute(sql, query_args)
-            return txn.fetchone()[0]
+            return cast(Tuple[int], txn.fetchone())[0]

         return await self.db_pool.runInteraction(
             "count_public_rooms", _count_public_rooms_txn
@@ -244,11 +247,11 @@ class RoomWorkerStore(SQLBaseStore):

     async def get_room_count(self) -> int:
         """Retrieve the total number of rooms."""

-        def f(txn):
+        def f(txn: LoggingTransaction) -> int:
             sql = "SELECT count(*) FROM rooms"
             txn.execute(sql)
-            row = txn.fetchone()
-            return row[0] or 0
+            row = cast(Tuple[int], txn.fetchone())
+            return row[0]

         return await self.db_pool.runInteraction("get_rooms", f)
@@ -260,7 +263,7 @@ class RoomWorkerStore(SQLBaseStore):
         bounds: Optional[Tuple[int, str]],
         forwards: bool,
         ignore_non_federatable: bool = False,
-    ):
+    ) -> List[Dict[str, Any]]:
         """Gets the largest public rooms (where largest is in terms of joined
         members, as tracked in the statistics table).
@@ -381,7 +384,9 @@ class RoomWorkerStore(SQLBaseStore):
             LIMIT ?
         """

-        def _get_largest_public_rooms_txn(txn):
+        def _get_largest_public_rooms_txn(
+            txn: LoggingTransaction,
+        ) -> List[Dict[str, Any]]:
             txn.execute(sql, query_args)

             results = self.db_pool.cursor_to_dict(txn)
@@ -444,7 +449,7 @@ class RoomWorkerStore(SQLBaseStore):
         """
         # Filter room names by a string
         where_statement = ""
-        search_pattern = []
+        search_pattern: List[object] = []
         if search_term:
             where_statement = """
                 WHERE LOWER(state.name) LIKE ?
@@ -552,7 +557,9 @@ class RoomWorkerStore(SQLBaseStore):
             where_statement,
         )

-        def _get_rooms_paginate_txn(txn):
+        def _get_rooms_paginate_txn(
+            txn: LoggingTransaction,
+        ) -> Tuple[List[Dict[str, Any]], int]:
             # Add the search term into the WHERE clause
             # and execute the data query
             txn.execute(info_sql, search_pattern + [limit, start])
@@ -584,7 +591,7 @@ class RoomWorkerStore(SQLBaseStore):

             # Add the search term into the WHERE clause if present
             txn.execute(count_sql, search_pattern)
-            room_count = txn.fetchone()
+            room_count = cast(Tuple[int], txn.fetchone())
             return rooms, room_count[0]

         return await self.db_pool.runInteraction(
@@ -629,7 +636,7 @@ class RoomWorkerStore(SQLBaseStore):
             burst_count: How many actions that can be performed before being limited.
         """

-        def set_ratelimit_txn(txn):
+        def set_ratelimit_txn(txn: LoggingTransaction) -> None:
             self.db_pool.simple_upsert_txn(
                 txn,
                 table="ratelimit_override",
@@ -652,7 +659,7 @@ class RoomWorkerStore(SQLBaseStore):
             user_id: user ID of the user
         """

-        def delete_ratelimit_txn(txn):
+        def delete_ratelimit_txn(txn: LoggingTransaction) -> None:
             row = self.db_pool.simple_select_one_txn(
                 txn,
                 table="ratelimit_override",
@@ -676,7 +683,7 @@ class RoomWorkerStore(SQLBaseStore):
         await self.db_pool.runInteraction("delete_ratelimit", delete_ratelimit_txn)

     @cached()
-    async def get_retention_policy_for_room(self, room_id):
+    async def get_retention_policy_for_room(self, room_id: str) -> Dict[str, int]:
         """Get the retention policy for a given room.

         If no retention policy has been found for this room, returns a policy defined
@@ -685,13 +692,15 @@ class RoomWorkerStore(SQLBaseStore):
         configuration).

         Args:
-            room_id (str): The ID of the room to get the retention policy of.
+            room_id: The ID of the room to get the retention policy of.

         Returns:
-            dict[int, int]: "min_lifetime" and "max_lifetime" for this room.
+            A dict containing "min_lifetime" and "max_lifetime" for this room.
         """

-        def get_retention_policy_for_room_txn(txn):
+        def get_retention_policy_for_room_txn(
+            txn: LoggingTransaction,
+        ) -> List[Dict[str, Optional[int]]]:
             txn.execute(
                 """
                 SELECT min_lifetime, max_lifetime FROM room_retention
@@ -716,19 +725,23 @@ class RoomWorkerStore(SQLBaseStore):
                 "max_lifetime": self.config.retention.retention_default_max_lifetime,
             }

-        row = ret[0]
+        min_lifetime = ret[0]["min_lifetime"]
+        max_lifetime = ret[0]["max_lifetime"]

         # If one of the room's policy's attributes isn't defined, use the matching
         # attribute from the default policy.
         # The default values will be None if no default policy has been defined, or if one
         # of the attributes is missing from the default policy.
-        if row["min_lifetime"] is None:
-            row["min_lifetime"] = self.config.retention.retention_default_min_lifetime
+        if min_lifetime is None:
+            min_lifetime = self.config.retention.retention_default_min_lifetime

-        if row["max_lifetime"] is None:
-            row["max_lifetime"] = self.config.retention.retention_default_max_lifetime
+        if max_lifetime is None:
+            max_lifetime = self.config.retention.retention_default_max_lifetime

-        return row
+        return {
+            "min_lifetime": min_lifetime,
+            "max_lifetime": max_lifetime,
+        }

     async def get_media_mxcs_in_room(self, room_id: str) -> Tuple[List[str], List[str]]:
         """Retrieves all the local and remote media MXC URIs in a given room
@@ -740,7 +753,9 @@ class RoomWorkerStore(SQLBaseStore):
             The local and remote media as a lists of the media IDs.
         """

-        def _get_media_mxcs_in_room_txn(txn):
+        def _get_media_mxcs_in_room_txn(
+            txn: LoggingTransaction,
+        ) -> Tuple[List[str], List[str]]:
             local_mxcs, remote_mxcs = self._get_media_mxcs_in_room_txn(txn, room_id)
             local_media_mxcs = []
             remote_media_mxcs = []
@@ -766,7 +781,7 @@ class RoomWorkerStore(SQLBaseStore):
         logger.info("Quarantining media in room: %s", room_id)

-        def _quarantine_media_in_room_txn(txn):
+        def _quarantine_media_in_room_txn(txn: LoggingTransaction) -> int:
             local_mxcs, remote_mxcs = self._get_media_mxcs_in_room_txn(txn, room_id)
             return self._quarantine_media_txn(
                 txn, local_mxcs, remote_mxcs, quarantined_by
@@ -776,13 +791,11 @@ class RoomWorkerStore(SQLBaseStore):
             "quarantine_media_in_room", _quarantine_media_in_room_txn
         )

-    def _get_media_mxcs_in_room_txn(self, txn, room_id):
+    def _get_media_mxcs_in_room_txn(
+        self, txn: LoggingTransaction, room_id: str
+    ) -> Tuple[List[str], List[Tuple[str, str]]]:
         """Retrieves all the local and remote media MXC URIs in a given room

-        Args:
-            txn (cursor)
-            room_id (str)
-
         Returns:
             The local and remote media as a lists of tuples where the key is
             the hostname and the value is the media ID.
@@ -850,7 +863,7 @@ class RoomWorkerStore(SQLBaseStore):
         logger.info("Quarantining media: %s/%s", server_name, media_id)
         is_local = server_name == self.config.server.server_name

-        def _quarantine_media_by_id_txn(txn):
+        def _quarantine_media_by_id_txn(txn: LoggingTransaction) -> int:
             local_mxcs = [media_id] if is_local else []
             remote_mxcs = [(server_name, media_id)] if not is_local else []
@@ -872,7 +885,7 @@ class RoomWorkerStore(SQLBaseStore):
             quarantined_by: The ID of the user who made the quarantine request
         """

-        def _quarantine_media_by_user_txn(txn):
+        def _quarantine_media_by_user_txn(txn: LoggingTransaction) -> int:
             local_media_ids = self._get_media_ids_by_user_txn(txn, user_id)
             return self._quarantine_media_txn(txn, local_media_ids, [], quarantined_by)
@@ -880,7 +893,9 @@ class RoomWorkerStore(SQLBaseStore):
             "quarantine_media_by_user", _quarantine_media_by_user_txn
         )

-    def _get_media_ids_by_user_txn(self, txn, user_id: str, filter_quarantined=True):
+    def _get_media_ids_by_user_txn(
+        self, txn: LoggingTransaction, user_id: str, filter_quarantined: bool = True
+    ) -> List[str]:
         """Retrieves local media IDs by a given user

         Args:
@@ -909,7 +924,7 @@ class RoomWorkerStore(SQLBaseStore):
     def _quarantine_media_txn(
         self,
-        txn,
+        txn: LoggingTransaction,
         local_mxcs: List[str],
         remote_mxcs: List[Tuple[str, str]],
         quarantined_by: Optional[str],
@@ -937,12 +952,15 @@ class RoomWorkerStore(SQLBaseStore):
             # set quarantine
             if quarantined_by is not None:
                 sql += "AND safe_from_quarantine = ?"
-                rows = [(quarantined_by, media_id, False) for media_id in local_mxcs]
+                txn.executemany(
+                    sql, [(quarantined_by, media_id, False) for media_id in local_mxcs]
+                )
             # remove from quarantine
             else:
-                rows = [(quarantined_by, media_id) for media_id in local_mxcs]
-
-            txn.executemany(sql, rows)
+                txn.executemany(
+                    sql, [(quarantined_by, media_id) for media_id in local_mxcs]
+                )

             # Note that a rowcount of -1 can be used to indicate no rows were affected.
             total_media_quarantined = txn.rowcount if txn.rowcount > 0 else 0
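
The quarantine refactor above exists because the two branches used to bind `rows` to lists of differently shaped tuples (3-tuples in one branch, 2-tuples in the other), which cannot be expressed as a single variable type without widening; calling `executemany` inside each branch keeps every parameter list homogeneous. A minimal sketch of the same idea against the stdlib `sqlite3` module (stand-in table and columns, not the Synapse schema):

    import sqlite3
    from typing import List, Optional


    def set_quarantine(
        conn: sqlite3.Connection, media_ids: List[str], quarantined_by: Optional[str]
    ) -> None:
        if quarantined_by is not None:
            # One statement takes 3-tuples...
            conn.executemany(
                "UPDATE media SET quarantined_by = ? WHERE media_id = ? AND safe = ?",
                [(quarantined_by, media_id, False) for media_id in media_ids],
            )
        else:
            # ...the other takes 2-tuples; keeping each list local to its branch
            # avoids giving one variable two incompatible element types.
            conn.executemany(
                "UPDATE media SET quarantined_by = ? WHERE media_id = ?",
                [(quarantined_by, media_id) for media_id in media_ids],
            )
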
@@ -960,7 +978,7 @@ class RoomWorkerStore(SQLBaseStore):
     async def get_rooms_for_retention_period_in_range(
         self, min_ms: Optional[int], max_ms: Optional[int], include_null: bool = False
-    ) -> Dict[str, dict]:
+    ) -> Dict[str, Dict[str, Optional[int]]]:
         """Retrieves all of the rooms within the given retention range.

         Optionally includes the rooms which don't have a retention policy.
@@ -980,7 +998,9 @@ class RoomWorkerStore(SQLBaseStore):
             "min_lifetime" (int|None), and "max_lifetime" (int|None).
         """

-        def get_rooms_for_retention_period_in_range_txn(txn):
+        def get_rooms_for_retention_period_in_range_txn(
+            txn: LoggingTransaction,
+        ) -> Dict[str, Dict[str, Optional[int]]]:
             range_conditions = []
             args = []
@@ -1067,8 +1087,6 @@ class RoomBackgroundUpdateStore(SQLBaseStore):
     ):
         super().__init__(database, db_conn, hs)

-        self.config = hs.config
-
         self.db_pool.updates.register_background_update_handler(
             "insert_room_retention",
             self._background_insert_retention,
@@ -1099,7 +1117,9 @@ class RoomBackgroundUpdateStore(SQLBaseStore):
             self._background_populate_rooms_creator_column,
         )

-    async def _background_insert_retention(self, progress, batch_size):
+    async def _background_insert_retention(
+        self, progress: JsonDict, batch_size: int
+    ) -> int:
         """Retrieves a list of all rooms within a range and inserts an entry for each of
         them into the room_retention table.
         NULLs the property's columns if missing from the retention event in the room's
@@ -1109,7 +1129,7 @@ class RoomBackgroundUpdateStore(SQLBaseStore):
         last_room = progress.get("room_id", "")

-        def _background_insert_retention_txn(txn):
+        def _background_insert_retention_txn(txn: LoggingTransaction) -> bool:
             txn.execute(
                 """
                 SELECT state.room_id, state.event_id, events.json
@@ -1168,15 +1188,17 @@ class RoomBackgroundUpdateStore(SQLBaseStore):
         return batch_size

     async def _background_add_rooms_room_version_column(
-        self, progress: dict, batch_size: int
-    ):
+        self, progress: JsonDict, batch_size: int
+    ) -> int:
         """Background update to go and add room version information to `rooms`
         table from `current_state_events` table.
         """
         last_room_id = progress.get("room_id", "")

-        def _background_add_rooms_room_version_column_txn(txn: LoggingTransaction):
+        def _background_add_rooms_room_version_column_txn(
+            txn: LoggingTransaction,
+        ) -> bool:
             sql = """
                 SELECT room_id, json FROM current_state_events
                 INNER JOIN event_json USING (room_id, event_id)
@@ -1237,7 +1259,7 @@ class RoomBackgroundUpdateStore(SQLBaseStore):
         return batch_size

     async def _remove_tombstoned_rooms_from_directory(
-        self, progress, batch_size
+        self, progress: JsonDict, batch_size: int
     ) -> int:
         """Removes any rooms with tombstone events from the room directory
@@ -1247,7 +1269,7 @@ class RoomBackgroundUpdateStore(SQLBaseStore):
         last_room = progress.get("room_id", "")

-        def _get_rooms(txn):
+        def _get_rooms(txn: LoggingTransaction) -> List[str]:
             txn.execute(
                 """
                 SELECT room_id
@@ -1285,7 +1307,7 @@ class RoomBackgroundUpdateStore(SQLBaseStore):
         return len(rooms)

     @abstractmethod
-    def set_room_is_public(self, room_id, is_public):
+    def set_room_is_public(self, room_id: str, is_public: bool) -> Awaitable[None]:
         # this will need to be implemented if a background update is performed with
         # existing (tombstoned, public) rooms in the database.
         #
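
The abstract `set_room_is_public` stub is annotated as returning `Awaitable[None]` because the concrete implementation elsewhere in the class hierarchy is an `async def`, and a coroutine satisfies `Awaitable[None]`. A standalone sketch of the pattern, with illustrative class names:

    import abc
    from typing import Awaitable


    class BackgroundStore(metaclass=abc.ABCMeta):
        @abc.abstractmethod
        def set_room_is_public(self, room_id: str, is_public: bool) -> Awaitable[None]:
            ...


    class MainStore(BackgroundStore):
        # An async def returns a coroutine, which is an Awaitable[None], so this
        # override type-checks against the abstract declaration above.
        async def set_room_is_public(self, room_id: str, is_public: bool) -> None:
            return None
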
@@ -1332,7 +1354,7 @@ class RoomBackgroundUpdateStore(SQLBaseStore):
             32-bit integer field.
         """

-        def process(txn: Cursor) -> int:
+        def process(txn: LoggingTransaction) -> int:
             last_room = progress.get("last_room", "")
             txn.execute(
                 """
@@ -1389,15 +1411,17 @@ class RoomBackgroundUpdateStore(SQLBaseStore):
         return 0

     async def _background_populate_rooms_creator_column(
-        self, progress: dict, batch_size: int
-    ):
+        self, progress: JsonDict, batch_size: int
+    ) -> int:
         """Background update to go and add creator information to `rooms`
         table from `current_state_events` table.
         """
         last_room_id = progress.get("room_id", "")

-        def _background_populate_rooms_creator_column_txn(txn: LoggingTransaction):
+        def _background_populate_rooms_creator_column_txn(
+            txn: LoggingTransaction,
+        ) -> bool:
             sql = """
                 SELECT room_id, json FROM event_json
                 INNER JOIN rooms AS room USING (room_id)
@@ -1448,7 +1472,7 @@ class RoomBackgroundUpdateStore(SQLBaseStore):
         return batch_size


-class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
+class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore):
     def __init__(
         self,
         database: DatabasePool,
@@ -1457,11 +1481,11 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
     ):
         super().__init__(database, db_conn, hs)

-        self.config = hs.config
+        self._event_reports_id_gen = IdGenerator(db_conn, "event_reports", "id")

     async def upsert_room_on_join(
         self, room_id: str, room_version: RoomVersion, auth_events: List[EventBase]
-    ):
+    ) -> None:
         """Ensure that the room is stored in the table

         Called when we join a room over federation, and overwrites any room version
@@ -1507,7 +1531,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
     async def maybe_store_room_on_outlier_membership(
         self, room_id: str, room_version: RoomVersion
-    ):
+    ) -> None:
         """
         When we receive an invite or any other event over federation that may relate to a room
         we are not in, store the version of the room if we don't already know the room version.
@@ -1547,8 +1571,8 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
         self.hs.get_notifier().on_new_replication_data()

     async def set_room_is_public_appservice(
-        self, room_id, appservice_id, network_id, is_public
-    ):
+        self, room_id: str, appservice_id: str, network_id: str, is_public: bool
+    ) -> None:
         """Edit the appservice/network specific public room list.

         Each appservice can have a number of published room lists associated
@@ -1557,11 +1581,10 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
         network.

         Args:
-            room_id (str)
-            appservice_id (str)
-            network_id (str)
-            is_public (bool): Whether to publish or unpublish the room from the
-                list.
+            room_id
+            appservice_id
+            network_id
+            is_public: Whether to publish or unpublish the room from the list.
         """

         if is_public:
@@ -1626,7 +1649,9 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
             event_report: json list of information from event report
         """

-        def _get_event_report_txn(txn, report_id):
+        def _get_event_report_txn(
+            txn: LoggingTransaction, report_id: int
+        ) -> Optional[Dict[str, Any]]:
             sql = """
                 SELECT
@@ -1698,9 +1723,11 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
             count: total number of event reports matching the filter criteria
         """

-        def _get_event_reports_paginate_txn(txn):
+        def _get_event_reports_paginate_txn(
+            txn: LoggingTransaction,
+        ) -> Tuple[List[Dict[str, Any]], int]:
             filters = []
-            args = []
+            args: List[object] = []

             if user_id:
                 filters.append("er.user_id LIKE ?")
@@ -1724,7 +1751,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
                 where_clause
             )
             txn.execute(sql, args)
-            count = txn.fetchone()[0]
+            count = cast(Tuple[int], txn.fetchone())[0]

             sql = """
                 SELECT