Additional type hints for relations database class. (#11205)

pull/11209/head
Patrick Cloke 2021-10-28 14:35:12 -04:00 committed by GitHub
parent 0e16b418f6
commit 56e281bf6c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 25 additions and 15 deletions

1
changelog.d/11205.misc Normal file
View File

@ -0,0 +1 @@
Improve type hints for the relations datastore.

View File

@ -53,6 +53,7 @@ files =
synapse/storage/databases/main/keys.py, synapse/storage/databases/main/keys.py,
synapse/storage/databases/main/pusher.py, synapse/storage/databases/main/pusher.py,
synapse/storage/databases/main/registration.py, synapse/storage/databases/main/registration.py,
synapse/storage/databases/main/relations.py,
synapse/storage/databases/main/session.py, synapse/storage/databases/main/session.py,
synapse/storage/databases/main/stream.py, synapse/storage/databases/main/stream.py,
synapse/storage/databases/main/ui_auth.py, synapse/storage/databases/main/ui_auth.py,

View File

@ -13,13 +13,14 @@
# limitations under the License. # limitations under the License.
import logging import logging
from typing import Optional, Tuple from typing import List, Optional, Tuple, Union
import attr import attr
from synapse.api.constants import RelationTypes from synapse.api.constants import RelationTypes
from synapse.events import EventBase from synapse.events import EventBase
from synapse.storage._base import SQLBaseStore from synapse.storage._base import SQLBaseStore
from synapse.storage.database import LoggingTransaction
from synapse.storage.databases.main.stream import generate_pagination_where_clause from synapse.storage.databases.main.stream import generate_pagination_where_clause
from synapse.storage.relations import ( from synapse.storage.relations import (
AggregationPaginationToken, AggregationPaginationToken,
@ -63,7 +64,7 @@ class RelationsWorkerStore(SQLBaseStore):
""" """
where_clause = ["relates_to_id = ?"] where_clause = ["relates_to_id = ?"]
where_args = [event_id] where_args: List[Union[str, int]] = [event_id]
if relation_type is not None: if relation_type is not None:
where_clause.append("relation_type = ?") where_clause.append("relation_type = ?")
@ -80,8 +81,8 @@ class RelationsWorkerStore(SQLBaseStore):
pagination_clause = generate_pagination_where_clause( pagination_clause = generate_pagination_where_clause(
direction=direction, direction=direction,
column_names=("topological_ordering", "stream_ordering"), column_names=("topological_ordering", "stream_ordering"),
from_token=attr.astuple(from_token) if from_token else None, from_token=attr.astuple(from_token) if from_token else None, # type: ignore[arg-type]
to_token=attr.astuple(to_token) if to_token else None, to_token=attr.astuple(to_token) if to_token else None, # type: ignore[arg-type]
engine=self.database_engine, engine=self.database_engine,
) )
@ -106,7 +107,9 @@ class RelationsWorkerStore(SQLBaseStore):
order, order,
) )
def _get_recent_references_for_event_txn(txn): def _get_recent_references_for_event_txn(
txn: LoggingTransaction,
) -> PaginationChunk:
txn.execute(sql, where_args + [limit + 1]) txn.execute(sql, where_args + [limit + 1])
last_topo_id = None last_topo_id = None
@ -160,7 +163,7 @@ class RelationsWorkerStore(SQLBaseStore):
""" """
where_clause = ["relates_to_id = ?", "relation_type = ?"] where_clause = ["relates_to_id = ?", "relation_type = ?"]
where_args = [event_id, RelationTypes.ANNOTATION] where_args: List[Union[str, int]] = [event_id, RelationTypes.ANNOTATION]
if event_type: if event_type:
where_clause.append("type = ?") where_clause.append("type = ?")
@ -169,8 +172,8 @@ class RelationsWorkerStore(SQLBaseStore):
having_clause = generate_pagination_where_clause( having_clause = generate_pagination_where_clause(
direction=direction, direction=direction,
column_names=("COUNT(*)", "MAX(stream_ordering)"), column_names=("COUNT(*)", "MAX(stream_ordering)"),
from_token=attr.astuple(from_token) if from_token else None, from_token=attr.astuple(from_token) if from_token else None, # type: ignore[arg-type]
to_token=attr.astuple(to_token) if to_token else None, to_token=attr.astuple(to_token) if to_token else None, # type: ignore[arg-type]
engine=self.database_engine, engine=self.database_engine,
) )
@ -199,7 +202,9 @@ class RelationsWorkerStore(SQLBaseStore):
having_clause=having_clause, having_clause=having_clause,
) )
def _get_aggregation_groups_for_event_txn(txn): def _get_aggregation_groups_for_event_txn(
txn: LoggingTransaction,
) -> PaginationChunk:
txn.execute(sql, where_args + [limit + 1]) txn.execute(sql, where_args + [limit + 1])
next_batch = None next_batch = None
@ -254,11 +259,12 @@ class RelationsWorkerStore(SQLBaseStore):
LIMIT 1 LIMIT 1
""" """
def _get_applicable_edit_txn(txn): def _get_applicable_edit_txn(txn: LoggingTransaction) -> Optional[str]:
txn.execute(sql, (event_id, RelationTypes.REPLACE)) txn.execute(sql, (event_id, RelationTypes.REPLACE))
row = txn.fetchone() row = txn.fetchone()
if row: if row:
return row[0] return row[0]
return None
edit_id = await self.db_pool.runInteraction( edit_id = await self.db_pool.runInteraction(
"get_applicable_edit", _get_applicable_edit_txn "get_applicable_edit", _get_applicable_edit_txn
@ -267,7 +273,7 @@ class RelationsWorkerStore(SQLBaseStore):
if not edit_id: if not edit_id:
return None return None
return await self.get_event(edit_id, allow_none=True) return await self.get_event(edit_id, allow_none=True) # type: ignore[attr-defined]
@cached() @cached()
async def get_thread_summary( async def get_thread_summary(
@ -283,7 +289,9 @@ class RelationsWorkerStore(SQLBaseStore):
The number of items in the thread and the most recent response, if any. The number of items in the thread and the most recent response, if any.
""" """
def _get_thread_summary_txn(txn) -> Tuple[int, Optional[str]]: def _get_thread_summary_txn(
txn: LoggingTransaction,
) -> Tuple[int, Optional[str]]:
# Fetch the count of threaded events and the latest event ID. # Fetch the count of threaded events and the latest event ID.
# TODO Should this only allow m.room.message events. # TODO Should this only allow m.room.message events.
sql = """ sql = """
@ -312,7 +320,7 @@ class RelationsWorkerStore(SQLBaseStore):
AND relation_type = ? AND relation_type = ?
""" """
txn.execute(sql, (event_id, RelationTypes.THREAD)) txn.execute(sql, (event_id, RelationTypes.THREAD))
count = txn.fetchone()[0] count = txn.fetchone()[0] # type: ignore[index]
return count, latest_event_id return count, latest_event_id
@ -322,7 +330,7 @@ class RelationsWorkerStore(SQLBaseStore):
latest_event = None latest_event = None
if latest_event_id: if latest_event_id:
latest_event = await self.get_event(latest_event_id, allow_none=True) latest_event = await self.get_event(latest_event_id, allow_none=True) # type: ignore[attr-defined]
return count, latest_event return count, latest_event
@ -354,7 +362,7 @@ class RelationsWorkerStore(SQLBaseStore):
LIMIT 1; LIMIT 1;
""" """
def _get_if_user_has_annotated_event(txn): def _get_if_user_has_annotated_event(txn: LoggingTransaction) -> bool:
txn.execute( txn.execute(
sql, sql,
( (