Additional type hints for relations database class. (#11205)
parent
0e16b418f6
commit
56e281bf6c
|
@ -0,0 +1 @@
|
||||||
|
Improve type hints for the relations datastore.
|
1
mypy.ini
1
mypy.ini
|
@ -53,6 +53,7 @@ files =
|
||||||
synapse/storage/databases/main/keys.py,
|
synapse/storage/databases/main/keys.py,
|
||||||
synapse/storage/databases/main/pusher.py,
|
synapse/storage/databases/main/pusher.py,
|
||||||
synapse/storage/databases/main/registration.py,
|
synapse/storage/databases/main/registration.py,
|
||||||
|
synapse/storage/databases/main/relations.py,
|
||||||
synapse/storage/databases/main/session.py,
|
synapse/storage/databases/main/session.py,
|
||||||
synapse/storage/databases/main/stream.py,
|
synapse/storage/databases/main/stream.py,
|
||||||
synapse/storage/databases/main/ui_auth.py,
|
synapse/storage/databases/main/ui_auth.py,
|
||||||
|
|
|
@ -13,13 +13,14 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from typing import Optional, Tuple
|
from typing import List, Optional, Tuple, Union
|
||||||
|
|
||||||
import attr
|
import attr
|
||||||
|
|
||||||
from synapse.api.constants import RelationTypes
|
from synapse.api.constants import RelationTypes
|
||||||
from synapse.events import EventBase
|
from synapse.events import EventBase
|
||||||
from synapse.storage._base import SQLBaseStore
|
from synapse.storage._base import SQLBaseStore
|
||||||
|
from synapse.storage.database import LoggingTransaction
|
||||||
from synapse.storage.databases.main.stream import generate_pagination_where_clause
|
from synapse.storage.databases.main.stream import generate_pagination_where_clause
|
||||||
from synapse.storage.relations import (
|
from synapse.storage.relations import (
|
||||||
AggregationPaginationToken,
|
AggregationPaginationToken,
|
||||||
|
@ -63,7 +64,7 @@ class RelationsWorkerStore(SQLBaseStore):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
where_clause = ["relates_to_id = ?"]
|
where_clause = ["relates_to_id = ?"]
|
||||||
where_args = [event_id]
|
where_args: List[Union[str, int]] = [event_id]
|
||||||
|
|
||||||
if relation_type is not None:
|
if relation_type is not None:
|
||||||
where_clause.append("relation_type = ?")
|
where_clause.append("relation_type = ?")
|
||||||
|
@ -80,8 +81,8 @@ class RelationsWorkerStore(SQLBaseStore):
|
||||||
pagination_clause = generate_pagination_where_clause(
|
pagination_clause = generate_pagination_where_clause(
|
||||||
direction=direction,
|
direction=direction,
|
||||||
column_names=("topological_ordering", "stream_ordering"),
|
column_names=("topological_ordering", "stream_ordering"),
|
||||||
from_token=attr.astuple(from_token) if from_token else None,
|
from_token=attr.astuple(from_token) if from_token else None, # type: ignore[arg-type]
|
||||||
to_token=attr.astuple(to_token) if to_token else None,
|
to_token=attr.astuple(to_token) if to_token else None, # type: ignore[arg-type]
|
||||||
engine=self.database_engine,
|
engine=self.database_engine,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -106,7 +107,9 @@ class RelationsWorkerStore(SQLBaseStore):
|
||||||
order,
|
order,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _get_recent_references_for_event_txn(txn):
|
def _get_recent_references_for_event_txn(
|
||||||
|
txn: LoggingTransaction,
|
||||||
|
) -> PaginationChunk:
|
||||||
txn.execute(sql, where_args + [limit + 1])
|
txn.execute(sql, where_args + [limit + 1])
|
||||||
|
|
||||||
last_topo_id = None
|
last_topo_id = None
|
||||||
|
@ -160,7 +163,7 @@ class RelationsWorkerStore(SQLBaseStore):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
where_clause = ["relates_to_id = ?", "relation_type = ?"]
|
where_clause = ["relates_to_id = ?", "relation_type = ?"]
|
||||||
where_args = [event_id, RelationTypes.ANNOTATION]
|
where_args: List[Union[str, int]] = [event_id, RelationTypes.ANNOTATION]
|
||||||
|
|
||||||
if event_type:
|
if event_type:
|
||||||
where_clause.append("type = ?")
|
where_clause.append("type = ?")
|
||||||
|
@ -169,8 +172,8 @@ class RelationsWorkerStore(SQLBaseStore):
|
||||||
having_clause = generate_pagination_where_clause(
|
having_clause = generate_pagination_where_clause(
|
||||||
direction=direction,
|
direction=direction,
|
||||||
column_names=("COUNT(*)", "MAX(stream_ordering)"),
|
column_names=("COUNT(*)", "MAX(stream_ordering)"),
|
||||||
from_token=attr.astuple(from_token) if from_token else None,
|
from_token=attr.astuple(from_token) if from_token else None, # type: ignore[arg-type]
|
||||||
to_token=attr.astuple(to_token) if to_token else None,
|
to_token=attr.astuple(to_token) if to_token else None, # type: ignore[arg-type]
|
||||||
engine=self.database_engine,
|
engine=self.database_engine,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -199,7 +202,9 @@ class RelationsWorkerStore(SQLBaseStore):
|
||||||
having_clause=having_clause,
|
having_clause=having_clause,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _get_aggregation_groups_for_event_txn(txn):
|
def _get_aggregation_groups_for_event_txn(
|
||||||
|
txn: LoggingTransaction,
|
||||||
|
) -> PaginationChunk:
|
||||||
txn.execute(sql, where_args + [limit + 1])
|
txn.execute(sql, where_args + [limit + 1])
|
||||||
|
|
||||||
next_batch = None
|
next_batch = None
|
||||||
|
@ -254,11 +259,12 @@ class RelationsWorkerStore(SQLBaseStore):
|
||||||
LIMIT 1
|
LIMIT 1
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def _get_applicable_edit_txn(txn):
|
def _get_applicable_edit_txn(txn: LoggingTransaction) -> Optional[str]:
|
||||||
txn.execute(sql, (event_id, RelationTypes.REPLACE))
|
txn.execute(sql, (event_id, RelationTypes.REPLACE))
|
||||||
row = txn.fetchone()
|
row = txn.fetchone()
|
||||||
if row:
|
if row:
|
||||||
return row[0]
|
return row[0]
|
||||||
|
return None
|
||||||
|
|
||||||
edit_id = await self.db_pool.runInteraction(
|
edit_id = await self.db_pool.runInteraction(
|
||||||
"get_applicable_edit", _get_applicable_edit_txn
|
"get_applicable_edit", _get_applicable_edit_txn
|
||||||
|
@ -267,7 +273,7 @@ class RelationsWorkerStore(SQLBaseStore):
|
||||||
if not edit_id:
|
if not edit_id:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
return await self.get_event(edit_id, allow_none=True)
|
return await self.get_event(edit_id, allow_none=True) # type: ignore[attr-defined]
|
||||||
|
|
||||||
@cached()
|
@cached()
|
||||||
async def get_thread_summary(
|
async def get_thread_summary(
|
||||||
|
@ -283,7 +289,9 @@ class RelationsWorkerStore(SQLBaseStore):
|
||||||
The number of items in the thread and the most recent response, if any.
|
The number of items in the thread and the most recent response, if any.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def _get_thread_summary_txn(txn) -> Tuple[int, Optional[str]]:
|
def _get_thread_summary_txn(
|
||||||
|
txn: LoggingTransaction,
|
||||||
|
) -> Tuple[int, Optional[str]]:
|
||||||
# Fetch the count of threaded events and the latest event ID.
|
# Fetch the count of threaded events and the latest event ID.
|
||||||
# TODO Should this only allow m.room.message events.
|
# TODO Should this only allow m.room.message events.
|
||||||
sql = """
|
sql = """
|
||||||
|
@ -312,7 +320,7 @@ class RelationsWorkerStore(SQLBaseStore):
|
||||||
AND relation_type = ?
|
AND relation_type = ?
|
||||||
"""
|
"""
|
||||||
txn.execute(sql, (event_id, RelationTypes.THREAD))
|
txn.execute(sql, (event_id, RelationTypes.THREAD))
|
||||||
count = txn.fetchone()[0]
|
count = txn.fetchone()[0] # type: ignore[index]
|
||||||
|
|
||||||
return count, latest_event_id
|
return count, latest_event_id
|
||||||
|
|
||||||
|
@ -322,7 +330,7 @@ class RelationsWorkerStore(SQLBaseStore):
|
||||||
|
|
||||||
latest_event = None
|
latest_event = None
|
||||||
if latest_event_id:
|
if latest_event_id:
|
||||||
latest_event = await self.get_event(latest_event_id, allow_none=True)
|
latest_event = await self.get_event(latest_event_id, allow_none=True) # type: ignore[attr-defined]
|
||||||
|
|
||||||
return count, latest_event
|
return count, latest_event
|
||||||
|
|
||||||
|
@ -354,7 +362,7 @@ class RelationsWorkerStore(SQLBaseStore):
|
||||||
LIMIT 1;
|
LIMIT 1;
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def _get_if_user_has_annotated_event(txn):
|
def _get_if_user_has_annotated_event(txn: LoggingTransaction) -> bool:
|
||||||
txn.execute(
|
txn.execute(
|
||||||
sql,
|
sql,
|
||||||
(
|
(
|
||||||
|
|
Loading…
Reference in New Issue