clokep/psycopg3
Patrick Cloke 2023-09-29 13:44:29 -04:00
parent 208a5944a6
commit 29492b7e85
9 changed files with 37 additions and 28 deletions

View File

@ -293,7 +293,7 @@ all = [
# matrix-synapse-ldap3 # matrix-synapse-ldap3
"matrix-synapse-ldap3", "matrix-synapse-ldap3",
# postgres # postgres
"psycopg2", "psycopg2cffi", "psycopg2cffi-compat", "psycopg2", "psycopg2cffi", "psycopg2cffi-compat", "psycopg",
# saml2 # saml2
"pysaml2", "pysaml2",
# oidc and jwt # oidc and jwt

View File

@ -60,9 +60,9 @@ from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage.background_updates import BackgroundUpdater from synapse.storage.background_updates import BackgroundUpdater
from synapse.storage.engines import ( from synapse.storage.engines import (
BaseDatabaseEngine, BaseDatabaseEngine,
PostgresEngine, Psycopg2Engine,
PsycopgEngine, PsycopgEngine,
Sqlite3Engine, Psycopg2Engine, Sqlite3Engine,
) )
from synapse.storage.types import Connection, Cursor, SQLQueryParameters from synapse.storage.types import Connection, Cursor, SQLQueryParameters
from synapse.util.async_helpers import delay_cancellation from synapse.util.async_helpers import delay_cancellation
@ -426,18 +426,14 @@ class LoggingTransaction:
values, values,
) )
def copy_write( def copy_write(self, sql: str, args: Iterable[Iterable[Any]]) -> None:
self,
sql: str, args: Iterable[Iterable[Any]]
) -> None:
# TODO use _do_execute # TODO use _do_execute
with self.txn.copy(sql) as copy: with self.txn.copy(sql) as copy:
for record in args: for record in args:
copy.write_row(record) copy.write_row(record)
def copy_read( def copy_read(
self, self, sql: str, args: Iterable[Iterable[Any]]
sql: str, args: Iterable[Iterable[Any]]
) -> Iterable[Tuple[Any, ...]]: ) -> Iterable[Tuple[Any, ...]]:
# TODO use _do_execute # TODO use _do_execute
sql = self.database_engine.convert_param_style(sql) sql = self.database_engine.convert_param_style(sql)
@ -466,6 +462,7 @@ class LoggingTransaction:
"Strip newlines out of SQL so that the loggers in the DB are on one line" "Strip newlines out of SQL so that the loggers in the DB are on one line"
if isinstance(self.database_engine, PsycopgEngine): if isinstance(self.database_engine, PsycopgEngine):
import psycopg.sql import psycopg.sql
if isinstance(sql, psycopg.sql.Composed): if isinstance(sql, psycopg.sql.Composed):
return sql.as_string(None) return sql.as_string(None)
@ -1201,7 +1198,8 @@ class DatabasePool:
elif isinstance(txn.database_engine, PsycopgEngine): elif isinstance(txn.database_engine, PsycopgEngine):
sql = "COPY %s (%s) FROM STDIN" % ( sql = "COPY %s (%s) FROM STDIN" % (
table, ", ".join(k for k in keys), table,
", ".join(k for k in keys),
) )
txn.copy_write(sql, values) txn.copy_write(sql, values)

View File

@ -47,7 +47,7 @@ from synapse.storage.database import (
) )
from synapse.storage.databases.main.events_worker import EventsWorkerStore from synapse.storage.databases.main.events_worker import EventsWorkerStore
from synapse.storage.databases.main.signatures import SignatureWorkerStore from synapse.storage.databases.main.signatures import SignatureWorkerStore
from synapse.storage.engines import PostgresEngine, Sqlite3Engine, Psycopg2Engine from synapse.storage.engines import PostgresEngine, Psycopg2Engine, Sqlite3Engine
from synapse.types import JsonDict, StrCollection from synapse.types import JsonDict, StrCollection
from synapse.util import json_encoder from synapse.util import json_encoder
from synapse.util.caches.descriptors import cached from synapse.util.caches.descriptors import cached
@ -332,9 +332,13 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
AND sequence_number <= max_seq AND sequence_number <= max_seq
) )
TO STDOUT TO STDOUT
""" % (", ".join("(?, ?)" for _ in chains)) """ % (
", ".join("(?, ?)" for _ in chains)
)
# Flatten the arguments. # Flatten the arguments.
rows = txn.copy_read(sql, list(itertools.chain.from_iterable(chains.items()))) rows = txn.copy_read(
sql, list(itertools.chain.from_iterable(chains.items()))
)
results.update(r for r, in rows) results.update(r for r, in rows)
else: else:
# For SQLite we just fall back to doing a noddy for loop. # For SQLite we just fall back to doing a noddy for loop.
@ -622,7 +626,9 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
AND min_seq < sequence_number AND sequence_number <= max_seq AND min_seq < sequence_number AND sequence_number <= max_seq
) )
TO STDOUT TO STDOUT
""" % (", ".join("(?, ?, ?)" for _ in args)) """ % (
", ".join("(?, ?, ?)" for _ in args)
)
# Flatten the arguments. # Flatten the arguments.
rows = txn.copy_read(sql, list(itertools.chain.from_iterable(args))) rows = txn.copy_read(sql, list(itertools.chain.from_iterable(args)))
result.update(r for r, in rows) result.update(r for r, in rows)

View File

@ -96,10 +96,10 @@ from synapse.storage.database import (
DatabasePool, DatabasePool,
LoggingDatabaseConnection, LoggingDatabaseConnection,
LoggingTransaction, LoggingTransaction,
PostgresEngine,
) )
from synapse.storage.databases.main.receipts import ReceiptsWorkerStore from synapse.storage.databases.main.receipts import ReceiptsWorkerStore
from synapse.storage.databases.main.stream import StreamWorkerStore from synapse.storage.databases.main.stream import StreamWorkerStore
from synapse.storage.engines import PostgresEngine
from synapse.types import JsonDict from synapse.types import JsonDict
from synapse.util import json_encoder from synapse.util import json_encoder
from synapse.util.caches.descriptors import cached from synapse.util.caches.descriptors import cached

View File

@ -54,8 +54,12 @@ from synapse.storage.database import (
) )
from synapse.storage.databases.main.state import StateFilter from synapse.storage.databases.main.state import StateFilter
from synapse.storage.databases.main.state_deltas import StateDeltasStore from synapse.storage.databases.main.state_deltas import StateDeltasStore
from synapse.storage.engines import PostgresEngine, Sqlite3Engine, PsycopgEngine, \ from synapse.storage.engines import (
Psycopg2Engine PostgresEngine,
Psycopg2Engine,
PsycopgEngine,
Sqlite3Engine,
)
from synapse.types import ( from synapse.types import (
JsonDict, JsonDict,
UserID, UserID,

View File

@ -33,7 +33,9 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class PostgresEngine(BaseDatabaseEngine[ConnectionType, CursorType], metaclass=abc.ABCMeta): class PostgresEngine(
BaseDatabaseEngine[ConnectionType, CursorType], metaclass=abc.ABCMeta
):
isolation_level_map: Mapping[int, int] isolation_level_map: Mapping[int, int]
default_isolation_level: int default_isolation_level: int
OperationalError: Type[Exception] OperationalError: Type[Exception]
@ -168,7 +170,11 @@ class PostgresEngine(BaseDatabaseEngine[ConnectionType, CursorType], metaclass=a
if self.__class__.__name__ == "Psycopg2Engine": if self.__class__.__name__ == "Psycopg2Engine":
cursor.execute("SET statement_timeout TO ?", (self.statement_timeout,)) cursor.execute("SET statement_timeout TO ?", (self.statement_timeout,))
else: else:
cursor.execute(sql.SQL("SET statement_timeout TO {}").format(self.statement_timeout)) cursor.execute(
sql.SQL("SET statement_timeout TO {}").format(
self.statement_timeout
)
)
cursor.close() cursor.close()
db_conn.commit() db_conn.commit()

View File

@ -22,9 +22,7 @@ import psycopg.sql
from twisted.enterprise.adbapi import Connection as TxConnection from twisted.enterprise.adbapi import Connection as TxConnection
from synapse.storage.engines import PostgresEngine from synapse.storage.engines import PostgresEngine
from synapse.storage.engines._base import ( from synapse.storage.engines._base import IsolationLevel
IsolationLevel,
)
if TYPE_CHECKING: if TYPE_CHECKING:
pass pass

View File

@ -18,9 +18,7 @@ from typing import TYPE_CHECKING, Any, Mapping, Optional
import psycopg2.extensions import psycopg2.extensions
from synapse.storage.engines import PostgresEngine from synapse.storage.engines import PostgresEngine
from synapse.storage.engines._base import ( from synapse.storage.engines._base import IsolationLevel
IsolationLevel,
)
if TYPE_CHECKING: if TYPE_CHECKING:
pass pass

View File

@ -47,8 +47,7 @@ def run_create(cur: LoggingTransaction, database_engine: BaseDatabaseEngine) ->
if isinstance(database_engine, PsycopgEngine): if isinstance(database_engine, PsycopgEngine):
import psycopg.sql import psycopg.sql
cur.execute(
psycopg.sql.SQL(sql).format(args) cur.execute(psycopg.sql.SQL(sql).format(args))
)
else: else:
cur.execute(sql, args) cur.execute(sql, args)