Support providing an index predicate for upserts. (#13822)

This is useful to upsert against a table which has a unique
partial index while avoiding conflicts.
pull/13826/head
Patrick Cloke 2022-09-15 14:28:48 -04:00 committed by GitHub
parent 742f9f9d78
commit b2b0c85279
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 25 additions and 7 deletions

1
changelog.d/13822.misc Normal file
View File

@ -0,0 +1 @@
Support providing an index predicate clause when doing upserts.

View File

@ -533,6 +533,7 @@ class BackgroundUpdater:
index_name: name of index to add index_name: name of index to add
table: table to add index to table: table to add index to
columns: columns/expressions to include in index columns: columns/expressions to include in index
where_clause: A WHERE clause to specify a partial unique index.
unique: true to make a UNIQUE index unique: true to make a UNIQUE index
psql_only: true to only create this index on psql databases (useful psql_only: true to only create this index on psql databases (useful
for virtual sqlite tables) for virtual sqlite tables)

View File

@ -1191,6 +1191,7 @@ class DatabasePool:
keyvalues: Dict[str, Any], keyvalues: Dict[str, Any],
values: Dict[str, Any], values: Dict[str, Any],
insertion_values: Optional[Dict[str, Any]] = None, insertion_values: Optional[Dict[str, Any]] = None,
where_clause: Optional[str] = None,
lock: bool = True, lock: bool = True,
) -> bool: ) -> bool:
""" """
@ -1203,6 +1204,7 @@ class DatabasePool:
keyvalues: The unique key tables and their new values keyvalues: The unique key tables and their new values
values: The nonunique columns and their new values values: The nonunique columns and their new values
insertion_values: additional key/values to use only when inserting insertion_values: additional key/values to use only when inserting
where_clause: An index predicate to apply to the upsert.
lock: True to lock the table when doing the upsert. Unused when performing lock: True to lock the table when doing the upsert. Unused when performing
a native upsert. a native upsert.
Returns: Returns:
@ -1213,7 +1215,12 @@ class DatabasePool:
if table not in self._unsafe_to_upsert_tables: if table not in self._unsafe_to_upsert_tables:
return self.simple_upsert_txn_native_upsert( return self.simple_upsert_txn_native_upsert(
txn, table, keyvalues, values, insertion_values=insertion_values txn,
table,
keyvalues,
values,
insertion_values=insertion_values,
where_clause=where_clause,
) )
else: else:
return self.simple_upsert_txn_emulated( return self.simple_upsert_txn_emulated(
@ -1222,6 +1229,7 @@ class DatabasePool:
keyvalues, keyvalues,
values, values,
insertion_values=insertion_values, insertion_values=insertion_values,
where_clause=where_clause,
lock=lock, lock=lock,
) )
@ -1232,6 +1240,7 @@ class DatabasePool:
keyvalues: Dict[str, Any], keyvalues: Dict[str, Any],
values: Dict[str, Any], values: Dict[str, Any],
insertion_values: Optional[Dict[str, Any]] = None, insertion_values: Optional[Dict[str, Any]] = None,
where_clause: Optional[str] = None,
lock: bool = True, lock: bool = True,
) -> bool: ) -> bool:
""" """
@ -1240,6 +1249,7 @@ class DatabasePool:
keyvalues: The unique key tables and their new values keyvalues: The unique key tables and their new values
values: The nonunique columns and their new values values: The nonunique columns and their new values
insertion_values: additional key/values to use only when inserting insertion_values: additional key/values to use only when inserting
where_clause: An index predicate to apply to the upsert.
lock: True to lock the table when doing the upsert. lock: True to lock the table when doing the upsert.
Returns: Returns:
Returns True if a row was inserted or updated (i.e. if `values` is Returns True if a row was inserted or updated (i.e. if `values` is
@ -1259,14 +1269,17 @@ class DatabasePool:
else: else:
return "%s = ?" % (key,) return "%s = ?" % (key,)
# Generate a where clause of each keyvalue and optionally the provided
# index predicate.
where = [_getwhere(k) for k in keyvalues]
if where_clause:
where.append(where_clause)
if not values: if not values:
# If `values` is empty, then all of the values we care about are in # If `values` is empty, then all of the values we care about are in
# the unique key, so there is nothing to UPDATE. We can just do a # the unique key, so there is nothing to UPDATE. We can just do a
# SELECT instead to see if it exists. # SELECT instead to see if it exists.
sql = "SELECT 1 FROM %s WHERE %s" % ( sql = "SELECT 1 FROM %s WHERE %s" % (table, " AND ".join(where))
table,
" AND ".join(_getwhere(k) for k in keyvalues),
)
sqlargs = list(keyvalues.values()) sqlargs = list(keyvalues.values())
txn.execute(sql, sqlargs) txn.execute(sql, sqlargs)
if txn.fetchall(): if txn.fetchall():
@ -1277,7 +1290,7 @@ class DatabasePool:
sql = "UPDATE %s SET %s WHERE %s" % ( sql = "UPDATE %s SET %s WHERE %s" % (
table, table,
", ".join("%s = ?" % (k,) for k in values), ", ".join("%s = ?" % (k,) for k in values),
" AND ".join(_getwhere(k) for k in keyvalues), " AND ".join(where),
) )
sqlargs = list(values.values()) + list(keyvalues.values()) sqlargs = list(values.values()) + list(keyvalues.values())
@ -1307,6 +1320,7 @@ class DatabasePool:
keyvalues: Dict[str, Any], keyvalues: Dict[str, Any],
values: Dict[str, Any], values: Dict[str, Any],
insertion_values: Optional[Dict[str, Any]] = None, insertion_values: Optional[Dict[str, Any]] = None,
where_clause: Optional[str] = None,
) -> bool: ) -> bool:
""" """
Use the native UPSERT functionality in PostgreSQL. Use the native UPSERT functionality in PostgreSQL.
@ -1316,6 +1330,7 @@ class DatabasePool:
keyvalues: The unique key tables and their new values keyvalues: The unique key tables and their new values
values: The nonunique columns and their new values values: The nonunique columns and their new values
insertion_values: additional key/values to use only when inserting insertion_values: additional key/values to use only when inserting
where_clause: An index predicate to apply to the upsert.
Returns: Returns:
Returns True if a row was inserted or updated (i.e. if `values` is Returns True if a row was inserted or updated (i.e. if `values` is
@ -1331,11 +1346,12 @@ class DatabasePool:
allvalues.update(values) allvalues.update(values)
latter = "UPDATE SET " + ", ".join(k + "=EXCLUDED." + k for k in values) latter = "UPDATE SET " + ", ".join(k + "=EXCLUDED." + k for k in values)
sql = ("INSERT INTO %s (%s) VALUES (%s) ON CONFLICT (%s) DO %s") % ( sql = "INSERT INTO %s (%s) VALUES (%s) ON CONFLICT (%s) %s DO %s" % (
table, table,
", ".join(k for k in allvalues), ", ".join(k for k in allvalues),
", ".join("?" for _ in allvalues), ", ".join("?" for _ in allvalues),
", ".join(k for k in keyvalues), ", ".join(k for k in keyvalues),
f"WHERE {where_clause}" if where_clause else "",
latter, latter,
) )
txn.execute(sql, list(allvalues.values())) txn.execute(sql, list(allvalues.values()))