Require Collections as the parameters for simple_* methods. (#11580)
Instead of Iterable since the generators are not allowed due to the potential for their re-use.pull/11590/head
parent
323151b787
commit
f901f8b70e
|
@ -0,0 +1 @@
|
||||||
|
Add some safety checks that storage functions are used correctly.
|
|
@ -55,6 +55,7 @@ from synapse.metrics.background_process_metrics import run_as_background_process
|
||||||
from synapse.storage.background_updates import BackgroundUpdater
|
from synapse.storage.background_updates import BackgroundUpdater
|
||||||
from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine, Sqlite3Engine
|
from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine, Sqlite3Engine
|
||||||
from synapse.storage.types import Connection, Cursor
|
from synapse.storage.types import Connection, Cursor
|
||||||
|
from synapse.util.iterutils import batch_iter
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from synapse.server import HomeServer
|
from synapse.server import HomeServer
|
||||||
|
@ -986,7 +987,7 @@ class DatabasePool:
|
||||||
self,
|
self,
|
||||||
table: str,
|
table: str,
|
||||||
keys: Collection[str],
|
keys: Collection[str],
|
||||||
values: Iterable[Iterable[Any]],
|
values: Collection[Collection[Any]],
|
||||||
desc: str,
|
desc: str,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Executes an INSERT query on the named table.
|
"""Executes an INSERT query on the named table.
|
||||||
|
@ -1427,7 +1428,7 @@ class DatabasePool:
|
||||||
self,
|
self,
|
||||||
table: str,
|
table: str,
|
||||||
keyvalues: Dict[str, Any],
|
keyvalues: Dict[str, Any],
|
||||||
retcols: Iterable[str],
|
retcols: Collection[str],
|
||||||
allow_none: Literal[False] = False,
|
allow_none: Literal[False] = False,
|
||||||
desc: str = "simple_select_one",
|
desc: str = "simple_select_one",
|
||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
|
@ -1438,7 +1439,7 @@ class DatabasePool:
|
||||||
self,
|
self,
|
||||||
table: str,
|
table: str,
|
||||||
keyvalues: Dict[str, Any],
|
keyvalues: Dict[str, Any],
|
||||||
retcols: Iterable[str],
|
retcols: Collection[str],
|
||||||
allow_none: Literal[True] = True,
|
allow_none: Literal[True] = True,
|
||||||
desc: str = "simple_select_one",
|
desc: str = "simple_select_one",
|
||||||
) -> Optional[Dict[str, Any]]:
|
) -> Optional[Dict[str, Any]]:
|
||||||
|
@ -1448,7 +1449,7 @@ class DatabasePool:
|
||||||
self,
|
self,
|
||||||
table: str,
|
table: str,
|
||||||
keyvalues: Dict[str, Any],
|
keyvalues: Dict[str, Any],
|
||||||
retcols: Iterable[str],
|
retcols: Collection[str],
|
||||||
allow_none: bool = False,
|
allow_none: bool = False,
|
||||||
desc: str = "simple_select_one",
|
desc: str = "simple_select_one",
|
||||||
) -> Optional[Dict[str, Any]]:
|
) -> Optional[Dict[str, Any]]:
|
||||||
|
@ -1618,7 +1619,7 @@ class DatabasePool:
|
||||||
self,
|
self,
|
||||||
table: str,
|
table: str,
|
||||||
keyvalues: Optional[Dict[str, Any]],
|
keyvalues: Optional[Dict[str, Any]],
|
||||||
retcols: Iterable[str],
|
retcols: Collection[str],
|
||||||
desc: str = "simple_select_list",
|
desc: str = "simple_select_list",
|
||||||
) -> List[Dict[str, Any]]:
|
) -> List[Dict[str, Any]]:
|
||||||
"""Executes a SELECT query on the named table, which may return zero or
|
"""Executes a SELECT query on the named table, which may return zero or
|
||||||
|
@ -1681,7 +1682,7 @@ class DatabasePool:
|
||||||
table: str,
|
table: str,
|
||||||
column: str,
|
column: str,
|
||||||
iterable: Iterable[Any],
|
iterable: Iterable[Any],
|
||||||
retcols: Iterable[str],
|
retcols: Collection[str],
|
||||||
keyvalues: Optional[Dict[str, Any]] = None,
|
keyvalues: Optional[Dict[str, Any]] = None,
|
||||||
desc: str = "simple_select_many_batch",
|
desc: str = "simple_select_many_batch",
|
||||||
batch_size: int = 100,
|
batch_size: int = 100,
|
||||||
|
@ -1704,16 +1705,7 @@ class DatabasePool:
|
||||||
|
|
||||||
results: List[Dict[str, Any]] = []
|
results: List[Dict[str, Any]] = []
|
||||||
|
|
||||||
if not iterable:
|
for chunk in batch_iter(iterable, batch_size):
|
||||||
return results
|
|
||||||
|
|
||||||
# iterables can not be sliced, so convert it to a list first
|
|
||||||
it_list = list(iterable)
|
|
||||||
|
|
||||||
chunks = [
|
|
||||||
it_list[i : i + batch_size] for i in range(0, len(it_list), batch_size)
|
|
||||||
]
|
|
||||||
for chunk in chunks:
|
|
||||||
rows = await self.runInteraction(
|
rows = await self.runInteraction(
|
||||||
desc,
|
desc,
|
||||||
self.simple_select_many_txn,
|
self.simple_select_many_txn,
|
||||||
|
@ -1853,7 +1845,7 @@ class DatabasePool:
|
||||||
txn: LoggingTransaction,
|
txn: LoggingTransaction,
|
||||||
table: str,
|
table: str,
|
||||||
keyvalues: Dict[str, Any],
|
keyvalues: Dict[str, Any],
|
||||||
retcols: Iterable[str],
|
retcols: Collection[str],
|
||||||
allow_none: bool = False,
|
allow_none: bool = False,
|
||||||
) -> Optional[Dict[str, Any]]:
|
) -> Optional[Dict[str, Any]]:
|
||||||
select_sql = "SELECT %s FROM %s WHERE %s" % (
|
select_sql = "SELECT %s FROM %s WHERE %s" % (
|
||||||
|
@ -2146,7 +2138,7 @@ class DatabasePool:
|
||||||
table: str,
|
table: str,
|
||||||
term: Optional[str],
|
term: Optional[str],
|
||||||
col: str,
|
col: str,
|
||||||
retcols: Iterable[str],
|
retcols: Collection[str],
|
||||||
desc="simple_search_list",
|
desc="simple_search_list",
|
||||||
) -> Optional[List[Dict[str, Any]]]:
|
) -> Optional[List[Dict[str, Any]]]:
|
||||||
"""Executes a SELECT query on the named table, which may return zero or
|
"""Executes a SELECT query on the named table, which may return zero or
|
||||||
|
|
|
@ -22,7 +22,7 @@ from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
|
||||||
from synapse.storage.util.id_generators import StreamIdGenerator
|
from synapse.storage.util.id_generators import StreamIdGenerator
|
||||||
from synapse.types import JsonDict
|
from synapse.types import JsonDict
|
||||||
from synapse.util import json_encoder
|
from synapse.util import json_encoder
|
||||||
from synapse.util.caches.descriptors import cached, cachedList
|
from synapse.util.caches.descriptors import cached
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from synapse.server import HomeServer
|
from synapse.server import HomeServer
|
||||||
|
@ -196,27 +196,6 @@ class PusherWorkerStore(SQLBaseStore):
|
||||||
# This only exists for the cachedList decorator
|
# This only exists for the cachedList decorator
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
@cachedList(
|
|
||||||
cached_method_name="get_if_user_has_pusher",
|
|
||||||
list_name="user_ids",
|
|
||||||
num_args=1,
|
|
||||||
)
|
|
||||||
async def get_if_users_have_pushers(
|
|
||||||
self, user_ids: Iterable[str]
|
|
||||||
) -> Dict[str, bool]:
|
|
||||||
rows = await self.db_pool.simple_select_many_batch(
|
|
||||||
table="pushers",
|
|
||||||
column="user_name",
|
|
||||||
iterable=user_ids,
|
|
||||||
retcols=["user_name"],
|
|
||||||
desc="get_if_users_have_pushers",
|
|
||||||
)
|
|
||||||
|
|
||||||
result = {user_id: False for user_id in user_ids}
|
|
||||||
result.update({r["user_name"]: True for r in rows})
|
|
||||||
|
|
||||||
return result
|
|
||||||
|
|
||||||
async def update_pusher_last_stream_ordering(
|
async def update_pusher_last_stream_ordering(
|
||||||
self, app_id, pushkey, user_id, last_stream_ordering
|
self, app_id, pushkey, user_id, last_stream_ordering
|
||||||
) -> None:
|
) -> None:
|
||||||
|
|
Loading…
Reference in New Issue