clokep/psycopg3
Patrick Cloke 2023-09-22 15:26:21 -04:00
parent edff9f7dca
commit 4a0dfb336f
8 changed files with 93 additions and 45 deletions

View File

@ -62,7 +62,7 @@ from synapse.storage.engines import (
BaseDatabaseEngine, BaseDatabaseEngine,
PostgresEngine, PostgresEngine,
PsycopgEngine, PsycopgEngine,
Sqlite3Engine, Sqlite3Engine, Psycopg2Engine,
) )
from synapse.storage.types import Connection, Cursor, SQLQueryParameters from synapse.storage.types import Connection, Cursor, SQLQueryParameters
from synapse.util.async_helpers import delay_cancellation from synapse.util.async_helpers import delay_cancellation
@ -389,7 +389,7 @@ class LoggingTransaction:
More efficient than `executemany` on PostgreSQL More efficient than `executemany` on PostgreSQL
""" """
if isinstance(self.database_engine, PostgresEngine): if isinstance(self.database_engine, Psycopg2Engine):
from psycopg2.extras import execute_batch from psycopg2.extras import execute_batch
# TODO: is it safe for values to be Iterable[Iterable[Any]] here? # TODO: is it safe for values to be Iterable[Iterable[Any]] here?
@ -398,6 +398,8 @@ class LoggingTransaction:
self._do_execute( self._do_execute(
lambda the_sql: execute_batch(self.txn, the_sql, args), sql lambda the_sql: execute_batch(self.txn, the_sql, args), sql
) )
# TODO Can psycopg3 do anything better?
else: else:
# TODO: is it safe for values to be Iterable[Iterable[Any]] here? # TODO: is it safe for values to be Iterable[Iterable[Any]] here?
# https://docs.python.org/3/library/sqlite3.html?highlight=sqlite3#sqlite3.Cursor.executemany # https://docs.python.org/3/library/sqlite3.html?highlight=sqlite3#sqlite3.Cursor.executemany
@ -422,7 +424,8 @@ class LoggingTransaction:
The `template` is the snippet to merge to every item in argslist to The `template` is the snippet to merge to every item in argslist to
compose the query. compose the query.
""" """
assert isinstance(self.database_engine, PostgresEngine) assert isinstance(self.database_engine, Psycopg2Engine)
from psycopg2.extras import execute_values from psycopg2.extras import execute_values
return self._do_execute( return self._do_execute(
@ -435,6 +438,14 @@ class LoggingTransaction:
values, values,
) )
def copy(self, sql: str, args: Iterable[Iterable[Any]]) -> None:
    """Bulk-write rows via psycopg's COPY protocol.

    Opens a COPY stream for the given ``COPY ... FROM STDIN`` statement on
    the underlying cursor and writes each row of ``args`` into it. The
    stream is closed (and the data committed to the server) when the
    ``with`` block exits.

    Args:
        sql: a ``COPY table (cols) FROM STDIN`` statement.
        args: the rows to load, one iterable of column values per row.
    """
    with self.txn.copy(sql) as writer:
        for row in args:
            writer.write_row(row)
def execute(self, sql: str, parameters: SQLQueryParameters = ()) -> None: def execute(self, sql: str, parameters: SQLQueryParameters = ()) -> None:
self._do_execute(self.txn.execute, sql, parameters) self._do_execute(self.txn.execute, sql, parameters)
@ -1180,7 +1191,7 @@ class DatabasePool:
values: for each row, a list of values in the same order as `keys` values: for each row, a list of values in the same order as `keys`
""" """
if isinstance(txn.database_engine, PostgresEngine): if isinstance(txn.database_engine, Psycopg2Engine):
# We use `execute_values` as it can be a lot faster than `execute_batch`, # We use `execute_values` as it can be a lot faster than `execute_batch`,
# but it's only available on postgres. # but it's only available on postgres.
sql = "INSERT INTO %s (%s) VALUES ?" % ( sql = "INSERT INTO %s (%s) VALUES ?" % (
@ -1189,6 +1200,13 @@ class DatabasePool:
) )
txn.execute_values(sql, values, fetch=False) txn.execute_values(sql, values, fetch=False)
elif isinstance(txn.database_engine, PsycopgEngine):
sql = "COPY %s (%s) FROM STDIN" % (
table, ", ".join(k for k in keys),
)
txn.copy(sql, values)
else: else:
sql = "INSERT INTO %s (%s) VALUES(%s)" % ( sql = "INSERT INTO %s (%s) VALUES(%s)" % (
table, table,
@ -1606,7 +1624,7 @@ class DatabasePool:
for x, y in zip(key_values, value_values): for x, y in zip(key_values, value_values):
args.append(tuple(x) + tuple(y)) args.append(tuple(x) + tuple(y))
if isinstance(txn.database_engine, PostgresEngine): if isinstance(txn.database_engine, Psycopg2Engine):
# We use `execute_values` as it can be a lot faster than `execute_batch`, # We use `execute_values` as it can be a lot faster than `execute_batch`,
# but it's only available on postgres. # but it's only available on postgres.
sql = "INSERT INTO %s (%s) VALUES ? ON CONFLICT (%s) DO %s" % ( sql = "INSERT INTO %s (%s) VALUES ? ON CONFLICT (%s) DO %s" % (
@ -2362,7 +2380,7 @@ class DatabasePool:
values: for each row, a list of values in the same order as `keys` values: for each row, a list of values in the same order as `keys`
""" """
if isinstance(txn.database_engine, PostgresEngine): if isinstance(txn.database_engine, Psycopg2Engine):
# We use `execute_values` as it can be a lot faster than `execute_batch`, # We use `execute_values` as it can be a lot faster than `execute_batch`,
# but it's only available on postgres. # but it's only available on postgres.
sql = "DELETE FROM %s WHERE (%s) IN (VALUES ?)" % ( sql = "DELETE FROM %s WHERE (%s) IN (VALUES ?)" % (

View File

@ -47,7 +47,7 @@ from synapse.storage.database import (
) )
from synapse.storage.databases.main.events_worker import EventsWorkerStore from synapse.storage.databases.main.events_worker import EventsWorkerStore
from synapse.storage.databases.main.signatures import SignatureWorkerStore from synapse.storage.databases.main.signatures import SignatureWorkerStore
from synapse.storage.engines import PostgresEngine, Sqlite3Engine from synapse.storage.engines import PostgresEngine, Sqlite3Engine, Psycopg2Engine
from synapse.types import JsonDict, StrCollection from synapse.types import JsonDict, StrCollection
from synapse.util import json_encoder from synapse.util import json_encoder
from synapse.util.caches.descriptors import cached from synapse.util.caches.descriptors import cached
@ -321,7 +321,10 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
AND sequence_number <= max_seq AND sequence_number <= max_seq
""" """
if isinstance(self.database_engine, Psycopg2Engine):
rows = txn.execute_values(sql, chains.items()) rows = txn.execute_values(sql, chains.items())
else:
rows = txn.executemany(sql, chains.items())
results.update(r for r, in rows) results.update(r for r, in rows)
else: else:
# For SQLite we just fall back to doing a noddy for loop. # For SQLite we just fall back to doing a noddy for loop.
@ -597,7 +600,10 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
for chain_id, (min_no, max_no) in chain_to_gap.items() for chain_id, (min_no, max_no) in chain_to_gap.items()
] ]
if isinstance(self.database_engine, Psycopg2Engine):
rows = txn.execute_values(sql, args) rows = txn.execute_values(sql, args)
else:
rows = txn.executemany(sql, args)
result.update(r for r, in rows) result.update(r for r, in rows)
else: else:
# For SQLite we just fall back to doing a noddy for loop. # For SQLite we just fall back to doing a noddy for loop.

View File

@ -46,7 +46,7 @@ from synapse.storage.databases.main.stream import (
generate_pagination_bounds, generate_pagination_bounds,
generate_pagination_where_clause, generate_pagination_where_clause,
) )
from synapse.storage.engines import PostgresEngine from synapse.storage.engines import PostgresEngine, Psycopg2Engine
from synapse.types import JsonDict, StreamKeyType, StreamToken from synapse.types import JsonDict, StreamKeyType, StreamToken
from synapse.util.caches.descriptors import cached, cachedList from synapse.util.caches.descriptors import cached, cachedList
@ -139,7 +139,7 @@ class RelationsWorkerStore(SQLBaseStore):
ON CONFLICT (room_id, thread_id) ON CONFLICT (room_id, thread_id)
DO NOTHING DO NOTHING
""" """
if isinstance(txn.database_engine, PostgresEngine): if isinstance(txn.database_engine, Psycopg2Engine):
txn.execute_values(sql % ("?",), rows, fetch=False) txn.execute_values(sql % ("?",), rows, fetch=False)
else: else:
txn.execute_batch(sql % ("(?, ?, ?, ?, ?)",), rows) txn.execute_batch(sql % ("(?, ?, ?, ?, ?)",), rows)

View File

@ -54,7 +54,8 @@ from synapse.storage.database import (
) )
from synapse.storage.databases.main.state import StateFilter from synapse.storage.databases.main.state import StateFilter
from synapse.storage.databases.main.state_deltas import StateDeltasStore from synapse.storage.databases.main.state_deltas import StateDeltasStore
from synapse.storage.engines import PostgresEngine, Sqlite3Engine from synapse.storage.engines import PostgresEngine, Sqlite3Engine, PsycopgEngine, \
Psycopg2Engine
from synapse.types import ( from synapse.types import (
JsonDict, JsonDict,
UserID, UserID,
@ -717,19 +718,19 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
if isinstance(self.database_engine, PostgresEngine): if isinstance(self.database_engine, PostgresEngine):
# We weight the localpart most highly, then display name and finally # We weight the localpart most highly, then display name and finally
# server name # server name
template = """
(
%s,
setweight(to_tsvector('simple', %s), 'A')
|| setweight(to_tsvector('simple', %s), 'D')
|| setweight(to_tsvector('simple', COALESCE(%s, '')), 'B')
)
"""
sql = """ sql = """
INSERT INTO user_directory_search(user_id, vector) INSERT INTO user_directory_search(user_id, vector)
VALUES ? ON CONFLICT (user_id) DO UPDATE SET vector=EXCLUDED.vector VALUES
(
?,
setweight(to_tsvector('simple', ?), 'A')
|| setweight(to_tsvector('simple', ?), 'D')
|| setweight(to_tsvector('simple', COALESCE(?, '')), 'B')
)
ON CONFLICT (user_id) DO UPDATE SET vector=EXCLUDED.vector
""" """
if isinstance(self.database_engine, Psycopg2Engine):
txn.execute_values( txn.execute_values(
sql, sql,
[ [
@ -743,9 +744,23 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
) )
for p in profiles for p in profiles
], ],
template=template,
fetch=False, fetch=False,
) )
if isinstance(self.database_engine, PsycopgEngine):
txn.executemany(
sql,
[
(
p.user_id,
get_localpart_from_id(p.user_id),
get_domain_from_id(p.user_id),
_filter_text_for_index(p.display_name)
if p.display_name
else None,
)
for p in profiles
],
)
elif isinstance(self.database_engine, Sqlite3Engine): elif isinstance(self.database_engine, Sqlite3Engine):
values = [] values = []
for p in profiles: for p in profiles:

View File

@ -164,8 +164,11 @@ class PostgresEngine(BaseDatabaseEngine[ConnectionType, CursorType], metaclass=a
# Abort really long-running statements and turn them into errors. # Abort really long-running statements and turn them into errors.
if self.statement_timeout is not None: if self.statement_timeout is not None:
# TODO Avoid a circular import, this needs to be abstracted.
if self.__class__.__name__ == "Psycopg2Engine":
cursor.execute("SET statement_timeout TO ?", (self.statement_timeout,))
else:
cursor.execute(sql.SQL("SET statement_timeout TO {}").format(self.statement_timeout)) cursor.execute(sql.SQL("SET statement_timeout TO {}").format(self.statement_timeout))
#cursor.execute("SELECT set_config( 'statement_timeout', ?, false)", (self.statement_timeout,))
cursor.close() cursor.close()
db_conn.commit() db_conn.commit()

View File

@ -19,6 +19,8 @@ import psycopg
import psycopg.errors import psycopg.errors
import psycopg.sql import psycopg.sql
from twisted.enterprise.adbapi import Connection as TxConnection
from synapse.storage.engines import PostgresEngine from synapse.storage.engines import PostgresEngine
from synapse.storage.engines._base import ( from synapse.storage.engines._base import (
IsolationLevel, IsolationLevel,
@ -73,6 +75,10 @@ class PsycopgEngine(PostgresEngine[psycopg.Connection, psycopg.Cursor]):
def attempt_to_set_autocommit( def attempt_to_set_autocommit(
self, conn: psycopg.Connection, autocommit: bool self, conn: psycopg.Connection, autocommit: bool
) -> None: ) -> None:
# Sometimes this gets called with a Twisted connection instead, unwrap
# it because it doesn't support __setattr__.
if isinstance(conn, TxConnection):
conn = conn._connection
conn.autocommit = autocommit conn.autocommit = autocommit
def attempt_to_set_isolation_level( def attempt_to_set_isolation_level(

View File

@ -811,7 +811,7 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
"agg": "GREATEST" if self._positive else "LEAST", "agg": "GREATEST" if self._positive else "LEAST",
} }
pos = (self.get_current_token_for_writer(self._instance_name),) pos = self.get_current_token_for_writer(self._instance_name)
txn.execute(sql, (self._stream_name, self._instance_name, pos)) txn.execute(sql, (self._stream_name, self._instance_name, pos))

View File

@ -960,11 +960,11 @@ def setup_test_homeserver(
test_db = "synapse_test_%s" % uuid.uuid4().hex test_db = "synapse_test_%s" % uuid.uuid4().hex
if USE_POSTGRES_FOR_TESTS == "psycopg": if USE_POSTGRES_FOR_TESTS == "psycopg":
name = "psycopg" db_type = "psycopg"
else: else:
name = "psycopg2" db_type = "psycopg2"
database_config = { database_config = {
"name": name, "name": db_type,
"args": { "args": {
"dbname": test_db, "dbname": test_db,
"host": POSTGRES_HOST, "host": POSTGRES_HOST,