Don't use separate copy_read method.

clokep/psycopg3
Patrick Cloke 2023-09-29 14:24:54 -04:00
parent a072285e9d
commit a280d117dc
2 changed files with 60 additions and 83 deletions

View File

@ -14,6 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import itertools
import logging
import time
import types
@ -62,7 +63,7 @@ from synapse.storage.engines import (
BaseDatabaseEngine,
Psycopg2Engine,
PsycopgEngine,
Sqlite3Engine,
Sqlite3Engine, PostgresEngine,
)
from synapse.storage.types import Connection, Cursor, SQLQueryParameters
from synapse.util.async_helpers import delay_cancellation
@ -399,7 +400,7 @@ class LoggingTransaction:
def execute_values(
self,
sql: str,
values: Iterable[Iterable[Any]],
values: Sequence[Sequence[Any]],
template: Optional[str] = None,
fetch: bool = True,
) -> List[Tuple]:
@ -412,19 +413,43 @@ class LoggingTransaction:
The `template` is the snippet to merge to every item in argslist to
compose the query.
"""
assert isinstance(self.database_engine, Psycopg2Engine)
assert isinstance(self.database_engine, PostgresEngine)
from psycopg2.extras import execute_values
if isinstance(self.database_engine, Psycopg2Engine):
return self._do_execute(
# TODO: is it safe for values to be Iterable[Iterable[Any]] here?
# https://www.psycopg.org/docs/extras.html?highlight=execute_batch#psycopg2.extras.execute_values says values should be Sequence[Sequence]
lambda the_sql, the_values: execute_values(
self.txn, the_sql, the_values, template=template, fetch=fetch
),
sql,
values,
)
from psycopg2.extras import execute_values
return self._do_execute(
# TODO: is it safe for values to be Iterable[Iterable[Any]] here?
# https://www.psycopg.org/docs/extras.html?highlight=execute_batch#psycopg2.extras.execute_values says values should be Sequence[Sequence]
lambda the_sql, the_values: execute_values(
self.txn, the_sql, the_values, template=template, fetch=fetch
),
sql,
values,
)
else:
# We use fetch = False to mean a writable query. You *might* be able
# to morph that into a COPY (...) FROM STDIN, but it isn't worth the
# effort for the few places we set fetch = False.
assert fetch is True
# execute_values requires a single replacement, but we need to expand it
# for COPY. This assumes all inner sequences are the same length.
value_str = "(" + ", ".join("?" for _ in next(iter(values))) + ")"
sql = sql.replace("?", ", ".join(value_str for _ in values))
# Wrap the SQL in the COPY statement.
sql = f"COPY ({sql}) TO STDOUT"
def f(
the_sql: str, the_args: Sequence[Sequence[Any]]
) -> Iterable[Tuple[Any, ...]]:
with self.txn.copy(the_sql, the_args) as copy:
yield from copy.rows()
# Flatten the values.
return self._do_execute(f, sql, list(itertools.chain.from_iterable(values)))
def copy_write(
self, sql: str, args: Iterable[Any], values: Iterable[Iterable[Any]]
@ -441,20 +466,6 @@ class LoggingTransaction:
self._do_execute(f, sql, args, values)
def copy_read(
    self, sql: str, args: Iterable[Iterable[Any]]
) -> Iterable[Tuple[Any, ...]]:
    """Run a PostgreSQL `COPY (...) TO STDOUT` statement and read its rows.

    Only usable when the database engine is a `PsycopgEngine`; this is
    asserted before anything else happens.

    Args:
        sql: The COPY statement to run.
        args: Parameters handed to `self.txn.copy` together with `sql`.

    Returns:
        Whatever `self._do_execute` returns for the row-reading callback.
    """
    engine = self.database_engine
    assert isinstance(engine, PsycopgEngine)

    def f(
        the_sql: str, the_args: Iterable[Iterable[Any]]
    ) -> Iterable[Tuple[Any, ...]]:
        # This is a generator function: the COPY context is only entered
        # once the caller starts iterating over the result.
        with self.txn.copy(the_sql, the_args) as copy_stream:
            yield from copy_stream.rows()

    return self._do_execute(f, sql, args)
def execute(self, sql: str, parameters: SQLQueryParameters = ()) -> None:
    """Execute `sql` with `parameters` via `self.txn.execute`.

    The call is routed through `self._do_execute` rather than being made
    on the underlying cursor directly; nothing is returned.
    """
    txn_execute = self.txn.execute
    self._do_execute(txn_execute, sql, parameters)
@ -1187,7 +1198,7 @@ class DatabasePool:
txn: LoggingTransaction,
table: str,
keys: Collection[str],
values: Iterable[Iterable[Any]],
values: Sequence[Sequence[Any]],
) -> None:
"""Executes an INSERT query on the named table.

View File

@ -311,34 +311,16 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
results = set()
if isinstance(self.database_engine, PostgresEngine):
if isinstance(self.database_engine, Psycopg2Engine):
# We can use `execute_values` to efficiently fetch the gaps when
# using postgres.
sql = """
SELECT event_id
FROM event_auth_chains AS c, (VALUES ?) AS l(chain_id, max_seq)
WHERE
c.chain_id = l.chain_id
AND sequence_number <= max_seq
"""
rows = txn.execute_values(sql, chains.items())
else:
sql = """
COPY (
SELECT event_id
FROM event_auth_chains AS c, (VALUES %s) AS l(chain_id, max_seq)
WHERE
c.chain_id = l.chain_id
AND sequence_number <= max_seq
)
TO STDOUT
""" % (
", ".join("(?, ?)" for _ in chains)
)
# Flatten the arguments.
rows = txn.copy_read(
sql, list(itertools.chain.from_iterable(chains.items()))
)
# We can use `execute_values` to efficiently fetch the gaps when
# using postgres.
sql = """
SELECT event_id
FROM event_auth_chains AS c, (VALUES ?) AS l(chain_id, max_seq)
WHERE
c.chain_id = l.chain_id
AND sequence_number <= max_seq
"""
rows = txn.execute_values(sql, chains.items())
results.update(r for r, in rows)
else:
# For SQLite we just fall back to doing a noddy for loop.
@ -599,38 +581,22 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
return result
if isinstance(self.database_engine, PostgresEngine):
# We can use `execute_values` to efficiently fetch the gaps when
# using postgres.
sql = """
SELECT event_id
FROM event_auth_chains AS c, (VALUES ?) AS l(chain_id, min_seq, max_seq)
WHERE
c.chain_id = l.chain_id
AND min_seq < sequence_number AND sequence_number <= max_seq
"""
args = [
(chain_id, min_no, max_no)
for chain_id, (min_no, max_no) in chain_to_gap.items()
]
if isinstance(self.database_engine, Psycopg2Engine):
# We can use `execute_values` to efficiently fetch the gaps when
# using postgres.
sql = """
SELECT event_id
FROM event_auth_chains AS c, (VALUES ?) AS l(chain_id, min_seq, max_seq)
WHERE
c.chain_id = l.chain_id
AND min_seq < sequence_number AND sequence_number <= max_seq
"""
rows = txn.execute_values(sql, args)
else:
sql = """
COPY (
SELECT event_id
FROM event_auth_chains AS c, (VALUES %s) AS l(chain_id, min_seq, max_seq)
WHERE
c.chain_id = l.chain_id
AND min_seq < sequence_number AND sequence_number <= max_seq
)
TO STDOUT
""" % (
", ".join("(?, ?, ?)" for _ in args)
)
# Flatten the arguments.
rows = txn.copy_read(sql, list(itertools.chain.from_iterable(args)))
rows = txn.execute_values(sql, args)
result.update(r for r, in rows)
else:
# For SQLite we just fall back to doing a noddy for loop.