clokep/psycopg3
Patrick Cloke 2023-09-29 06:33:26 -04:00
parent 856faa8fce
commit 208a5944a6
3 changed files with 83 additions and 36 deletions

View File

@ -426,14 +426,24 @@ class LoggingTransaction:
values,
)
def copy(
def copy_write(
self,
sql: str, args: Iterable[Iterable[Any]]
) -> None:
# TODO use _do_execute
with self.txn.copy(sql) as copy:
for record in args:
copy.write_row(record)
def copy_read(
self,
sql: str, args: Iterable[Iterable[Any]]
) -> Iterable[Tuple[Any, ...]]:
# TODO use _do_execute
sql = self.database_engine.convert_param_style(sql)
with self.txn.copy(sql, args) as copy:
yield from copy.rows()
def execute(self, sql: str, parameters: SQLQueryParameters = ()) -> None:
self._do_execute(self.txn.execute, sql, parameters)
@ -1193,7 +1203,7 @@ class DatabasePool:
sql = "COPY %s (%s) FROM STDIN" % (
table, ", ".join(k for k in keys),
)
txn.copy(sql, values)
txn.copy_write(sql, values)
else:
sql = "INSERT INTO %s (%s) VALUES(%s)" % (

View File

@ -311,20 +311,30 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
results = set()
if isinstance(self.database_engine, PostgresEngine):
# We can use `execute_values` to efficiently fetch the gaps when
# using postgres.
sql = """
SELECT event_id
FROM event_auth_chains AS c, (VALUES ?) AS l(chain_id, max_seq)
WHERE
c.chain_id = l.chain_id
AND sequence_number <= max_seq
"""
if isinstance(self.database_engine, Psycopg2Engine):
# We can use `execute_values` to efficiently fetch the gaps when
# using postgres.
sql = """
SELECT event_id
FROM event_auth_chains AS c, (VALUES ?) AS l(chain_id, max_seq)
WHERE
c.chain_id = l.chain_id
AND sequence_number <= max_seq
"""
rows = txn.execute_values(sql, chains.items())
else:
rows = txn.executemany(sql, chains.items())
sql = """
COPY (
SELECT event_id
FROM event_auth_chains AS c, (VALUES %s) AS l(chain_id, max_seq)
WHERE
c.chain_id = l.chain_id
AND sequence_number <= max_seq
)
TO STDOUT
""" % (", ".join("(?, ?)" for _ in chains))
# Flatten the arguments.
rows = txn.copy_read(sql, list(itertools.chain.from_iterable(chains.items())))
results.update(r for r, in rows)
else:
# For SQLite we just fall back to doing a noddy for loop.
@ -585,25 +595,36 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
return result
if isinstance(self.database_engine, PostgresEngine):
# We can use `execute_values` to efficiently fetch the gaps when
# using postgres.
sql = """
SELECT event_id
FROM event_auth_chains AS c, (VALUES ?) AS l(chain_id, min_seq, max_seq)
WHERE
c.chain_id = l.chain_id
AND min_seq < sequence_number AND sequence_number <= max_seq
"""
args = [
(chain_id, min_no, max_no)
for chain_id, (min_no, max_no) in chain_to_gap.items()
]
if isinstance(self.database_engine, Psycopg2Engine):
# We can use `execute_values` to efficiently fetch the gaps when
# using postgres.
sql = """
SELECT event_id
FROM event_auth_chains AS c, (VALUES ?) AS l(chain_id, min_seq, max_seq)
WHERE
c.chain_id = l.chain_id
AND min_seq < sequence_number AND sequence_number <= max_seq
"""
rows = txn.execute_values(sql, args)
else:
rows = txn.executemany(sql, args)
sql = """
COPY (
SELECT event_id
FROM event_auth_chains AS c, (VALUES %s) AS l(chain_id, min_seq, max_seq)
WHERE
c.chain_id = l.chain_id
AND min_seq < sequence_number AND sequence_number <= max_seq
)
TO STDOUT
""" % (", ".join("(?, ?, ?)" for _ in args))
# Flatten the arguments.
rows = txn.copy_read(sql, list(itertools.chain.from_iterable(args)))
result.update(r for r, in rows)
else:
# For SQLite we just fall back to doing a noddy for loop.

View File

@ -719,18 +719,21 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
# We weight the localpart most highly, then display name and finally
# server name
sql = """
INSERT INTO user_directory_search(user_id, vector)
VALUES
(
?,
setweight(to_tsvector('simple', ?), 'A')
|| setweight(to_tsvector('simple', ?), 'D')
|| setweight(to_tsvector('simple', COALESCE(?, '')), 'B')
)
ON CONFLICT (user_id) DO UPDATE SET vector=EXCLUDED.vector
"""
if isinstance(self.database_engine, Psycopg2Engine):
template = """
(
%s,
setweight(to_tsvector('simple', %s), 'A')
|| setweight(to_tsvector('simple', %s), 'D')
|| setweight(to_tsvector('simple', COALESCE(%s, '')), 'B')
)
"""
sql = """
INSERT INTO user_directory_search(user_id, vector)
VALUES ? ON CONFLICT (user_id) DO UPDATE SET vector=EXCLUDED.vector
"""
txn.execute_values(
sql,
[
@ -744,9 +747,22 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
)
for p in profiles
],
template=template,
fetch=False,
)
if isinstance(self.database_engine, PsycopgEngine):
elif isinstance(self.database_engine, PsycopgEngine):
sql = """
INSERT INTO user_directory_search(user_id, vector)
VALUES
(
?,
setweight(to_tsvector('simple', ?), 'A')
|| setweight(to_tsvector('simple', ?), 'D')
|| setweight(to_tsvector('simple', COALESCE(?, '')), 'B')
)
ON CONFLICT (user_id) DO UPDATE SET vector=EXCLUDED.vector
"""
txn.executemany(
sql,
[