checks for generators in database functions (#11564)
A couple of safety-checks to hopefully stop people doing what I just did, and create a storage function which only works the first time it is called (and not when it is re-run due to a database concurrency error or similar).pull/11572/head
parent
eb39da6782
commit
ff6fd52160
|
@ -0,0 +1 @@
|
|||
Add some safety checks that storage functions are used correctly.
|
|
@ -13,8 +13,10 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import inspect
|
||||
import logging
|
||||
import time
|
||||
import types
|
||||
from collections import defaultdict
|
||||
from sys import intern
|
||||
from time import monotonic as monotonic_time
|
||||
|
@ -526,6 +528,12 @@ class DatabasePool:
|
|||
the function will correctly handle being aborted and retried half way
|
||||
through its execution.
|
||||
|
||||
Similarly, the arguments to `func` (`args`, `kwargs`) should not be generators,
|
||||
since they could be evaluated multiple times (which would produce an empty
|
||||
result on the second or subsequent evaluation). Likewise, the closure of `func`
|
||||
must not reference any generators. This method attempts to detect such usage
|
||||
and will log an error.
|
||||
|
||||
Args:
|
||||
conn
|
||||
desc
|
||||
|
@ -536,6 +544,39 @@ class DatabasePool:
|
|||
**kwargs
|
||||
"""
|
||||
|
||||
# Robustness check: ensure that none of the arguments are generators, since that
|
||||
# will fail if we have to repeat the transaction.
|
||||
# For now, we just log an error, and hope that it works on the first attempt.
|
||||
# TODO: raise an exception.
|
||||
for i, arg in enumerate(args):
|
||||
if inspect.isgenerator(arg):
|
||||
logger.error(
|
||||
"Programming error: generator passed to new_transaction as "
|
||||
"argument %i to function %s",
|
||||
i,
|
||||
func,
|
||||
)
|
||||
for name, val in kwargs.items():
|
||||
if inspect.isgenerator(val):
|
||||
logger.error(
|
||||
"Programming error: generator passed to new_transaction as "
|
||||
"argument %s to function %s",
|
||||
name,
|
||||
func,
|
||||
)
|
||||
# also check variables referenced in func's closure
|
||||
if inspect.isfunction(func):
|
||||
f = cast(types.FunctionType, func)
|
||||
if f.__closure__:
|
||||
for i, cell in enumerate(f.__closure__):
|
||||
if inspect.isgenerator(cell.cell_contents):
|
||||
logger.error(
|
||||
"Programming error: function %s references generator %s "
|
||||
"via its closure",
|
||||
f,
|
||||
f.__code__.co_freevars[i],
|
||||
)
|
||||
|
||||
start = monotonic_time()
|
||||
txn_id = self._TXN_ID
|
||||
|
||||
|
@ -1226,9 +1267,9 @@ class DatabasePool:
|
|||
self,
|
||||
table: str,
|
||||
key_names: Collection[str],
|
||||
key_values: Collection[Iterable[Any]],
|
||||
key_values: Collection[Collection[Any]],
|
||||
value_names: Collection[str],
|
||||
value_values: Iterable[Iterable[Any]],
|
||||
value_values: Collection[Collection[Any]],
|
||||
desc: str,
|
||||
) -> None:
|
||||
"""
|
||||
|
@ -1920,7 +1961,7 @@ class DatabasePool:
|
|||
self,
|
||||
table: str,
|
||||
column: str,
|
||||
iterable: Iterable[Any],
|
||||
iterable: Collection[Any],
|
||||
keyvalues: Dict[str, Any],
|
||||
desc: str,
|
||||
) -> int:
|
||||
|
@ -1931,7 +1972,8 @@ class DatabasePool:
|
|||
Args:
|
||||
table: string giving the table name
|
||||
column: column name to test for inclusion against `iterable`
|
||||
iterable: list
|
||||
iterable: list of values to match against `column`. NB cannot be a generator
|
||||
as it may be evaluated multiple times.
|
||||
keyvalues: dict of column names and values to select the rows with
|
||||
desc: description of the transaction, for logging and metrics
|
||||
|
||||
|
|
|
@ -269,6 +269,7 @@ class PresenceStore(PresenceBackgroundUpdateStore):
|
|||
"""
|
||||
# Add user entries to the table, updating the presence_stream_id column if the user already
|
||||
# exists in the table.
|
||||
presence_stream_id = self._presence_id_gen.get_current_token()
|
||||
await self.db_pool.simple_upsert_many(
|
||||
table="users_to_send_full_presence_to",
|
||||
key_names=("user_id",),
|
||||
|
@ -279,9 +280,7 @@ class PresenceStore(PresenceBackgroundUpdateStore):
|
|||
# devices at different times, each device will receive full presence once - when
|
||||
# the presence stream ID in their sync token is less than the one in the table
|
||||
# for their user ID.
|
||||
value_values=(
|
||||
(self._presence_id_gen.get_current_token(),) for _ in user_ids
|
||||
),
|
||||
value_values=[(presence_stream_id,) for _ in user_ids],
|
||||
desc="add_users_to_send_full_presence_to",
|
||||
)
|
||||
|
||||
|
|
Loading…
Reference in New Issue