Working version.

clokep/psycopg3
Patrick Cloke 2022-07-29 15:15:18 -04:00
parent f5ef7e13d7
commit 9f0ccbdbaf
6 changed files with 58 additions and 57 deletions

View File

@@ -56,7 +56,7 @@ from synapse.logging.context import (
 from synapse.metrics import register_threadpool
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.storage.background_updates import BackgroundUpdater
-from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine, Sqlite3Engine
+from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine, PsycopgEngine, Sqlite3Engine
 from synapse.storage.types import Connection, Cursor
 from synapse.util.async_helpers import delay_cancellation
 from synapse.util.iterutils import batch_iter
@@ -334,7 +334,8 @@ class LoggingTransaction:
     def fetchone(self) -> Optional[Tuple]:
         return self.txn.fetchone()

-    def fetchmany(self, size: Optional[int] = None) -> List[Tuple]:
+    def fetchmany(self, size: int = 0) -> List[Tuple]:
+        # XXX This can also be called with no arguments.
         return self.txn.fetchmany(size=size)

     def fetchall(self) -> List[Tuple]:
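
Context for the fetchmany() change above: psycopg 3 annotates Cursor.fetchmany(size: int = 0), where 0 falls back to cursor.arraysize, so the wrapper (and the Cursor protocol later in this commit) drops the Optional[int] = None default. A minimal hedged sketch, assuming a reachable database (the DSN is illustrative):

    import psycopg

    with psycopg.connect("dbname=synapse") as conn:
        with conn.cursor() as cur:
            cur.execute("SELECT * FROM generate_series(1, 10)")
            print(len(cur.fetchmany()))        # omitted size falls back to cur.arraysize
            print(len(cur.fetchmany(size=3)))  # at most three rows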
@@ -400,6 +401,11 @@ class LoggingTransaction:
     def _make_sql_one_line(self, sql: str) -> str:
         "Strip newlines out of SQL so that the loggers in the DB are on one line"
+        if isinstance(self.database_engine, PsycopgEngine):
+            import psycopg.sql
+
+            if isinstance(sql, psycopg.sql.Composed):
+                return sql.as_string(None)
         return " ".join(line.strip() for line in sql.splitlines() if line.strip())

     def _do_execute(
@@ -440,7 +446,7 @@ class LoggingTransaction:
         finally:
             secs = time.time() - start
             sql_logger.debug("[SQL time] {%s} %f sec", self.name, secs)
-            sql_query_timer.labels(sql.split()[0]).observe(secs)
+            sql_query_timer.labels(one_line_sql.split()[0]).observe(secs)

     def close(self) -> None:
         self.txn.close()
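
With psycopg 3 a query can reach the logging layer as a psycopg.sql.Composed object rather than a plain string, which is what the _make_sql_one_line() and sql_query_timer changes above handle. A hedged sketch of the same idea outside the class (function name illustrative):

    import psycopg.sql

    def make_sql_one_line(sql) -> str:
        if isinstance(sql, psycopg.sql.Composed):
            # Composed objects have no splitlines(); render them to a string first.
            return sql.as_string(None)
        return " ".join(line.strip() for line in sql.splitlines() if line.strip())

    query = psycopg.sql.SQL("SELECT * FROM {}").format(psycopg.sql.Identifier("users"))
    print(make_sql_one_line(query))                     # e.g. SELECT * FROM "users"
    print(make_sql_one_line("SELECT 1\n  FROM users"))  # SELECT 1 FROM users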

View File

@@ -21,7 +21,7 @@ from ._base import BaseDatabaseEngine, IncorrectDatabaseSetup
 # installed. To account for this, create dummy classes on import failure so we can
 # still run `isinstance()` checks.
 def dummy_engine(name: str, module: str) -> BaseDatabaseEngine:
-    class Engine(BaseDatabaseEngine):  # type: ignore[no-redef]
+    class Engine(BaseDatabaseEngine):
         def __new__(cls, *args: object, **kwargs: object) -> NoReturn:  # type: ignore[misc]
             raise RuntimeError(
                 f"Cannot create {name}Engine -- {module} module is not installed"
@@ -33,17 +33,17 @@ def dummy_engine(name: str, module: str) -> BaseDatabaseEngine:
 try:
     from .postgres import PostgresEngine
 except ImportError:
-    PostgresEngine = dummy_engine("PostgresEngine", "psycopg2")
+    PostgresEngine = dummy_engine("PostgresEngine", "psycopg2")  # type: ignore[misc,assignment]

 try:
     from .psycopg import PsycopgEngine
 except ImportError:
-    PsycopgEngine = dummy_engine("PsycopgEngine", "psycopg")
+    PsycopgEngine = dummy_engine("PsycopgEngine", "psycopg")  # type: ignore[misc,assignment]

 try:
     from .sqlite import Sqlite3Engine
 except ImportError:
-    Sqlite3Engine = dummy_engine("Sqlite3Engine", "sqlite3")
+    Sqlite3Engine = dummy_engine("Sqlite3Engine", "sqlite3")  # type: ignore[misc,assignment]


 def create_engine(database_config: Mapping[str, Any]) -> BaseDatabaseEngine:
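
The dummy_engine() fallback above keeps isinstance() checks usable when a driver is missing, while still failing loudly on construction. A hypothetical interactive example, assuming psycopg is not installed:

    from synapse.storage.engines import PsycopgEngine

    isinstance(object(), PsycopgEngine)  # False, and no ImportError at check time
    PsycopgEngine({})                    # RuntimeError: Cannot create PsycopgEngine -- psycopg module is not installed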

View File

@@ -15,7 +15,9 @@
 import logging
 from typing import TYPE_CHECKING, Any, Mapping, NoReturn, Optional, Tuple, cast

-import psycopg2.extensions
+import psycopg
+import psycopg.errors
+import psycopg.sql

 from synapse.storage.engines._base import (
     BaseDatabaseEngine,
@@ -31,28 +33,26 @@ if TYPE_CHECKING:
 logger = logging.getLogger(__name__)


-class PsycopgEngine(BaseDatabaseEngine[psycopg2.extensions.connection]):
+class PsycopgEngine(BaseDatabaseEngine[psycopg.Connection]):
     def __init__(self, database_config: Mapping[str, Any]):
-        super().__init__(psycopg2, database_config)
-        psycopg2.extensions.register_type(psycopg2.extensions.UNICODE)
+        super().__init__(psycopg, database_config)
+        # psycopg2.extensions.register_type(psycopg2.extensions.UNICODE)

         # Disables passing `bytes` to txn.execute, c.f. #6186. If you do
         # actually want to use bytes than wrap it in `bytearray`.
-        def _disable_bytes_adapter(_: bytes) -> NoReturn:
-            raise Exception("Passing bytes to DB is disabled.")
+        # def _disable_bytes_adapter(_: bytes) -> NoReturn:
+        #     raise Exception("Passing bytes to DB is disabled.")

-        psycopg2.extensions.register_adapter(bytes, _disable_bytes_adapter)
+        # psycopg2.extensions.register_adapter(bytes, _disable_bytes_adapter)

         self.synchronous_commit: bool = database_config.get("synchronous_commit", True)
         self._version: Optional[int] = None  # unknown as yet

-        self.isolation_level_map: Mapping[int, int] = {
-            IsolationLevel.READ_COMMITTED: psycopg2.extensions.ISOLATION_LEVEL_READ_COMMITTED,
-            IsolationLevel.REPEATABLE_READ: psycopg2.extensions.ISOLATION_LEVEL_REPEATABLE_READ,
-            IsolationLevel.SERIALIZABLE: psycopg2.extensions.ISOLATION_LEVEL_SERIALIZABLE,
+        self.isolation_level_map: Mapping[int, psycopg.IsolationLevel] = {
+            IsolationLevel.READ_COMMITTED: psycopg.IsolationLevel.READ_COMMITTED,
+            IsolationLevel.REPEATABLE_READ: psycopg.IsolationLevel.REPEATABLE_READ,
+            IsolationLevel.SERIALIZABLE: psycopg.IsolationLevel.SERIALIZABLE,
         }
-        self.default_isolation_level = (
-            psycopg2.extensions.ISOLATION_LEVEL_REPEATABLE_READ
-        )
+        self.default_isolation_level = psycopg.IsolationLevel.REPEATABLE_READ
         self.config = database_config

     @property
@@ -68,14 +68,14 @@ class PsycopgEngine(BaseDatabaseEngine[psycopg2.extensions.connection]):
     def check_database(
         self,
-        db_conn: psycopg2.extensions.connection,
+        db_conn: psycopg.Connection,
         allow_outdated_version: bool = False,
     ) -> None:
         # Get the version of PostgreSQL that we're using. As per the psycopg2
         # docs: The number is formed by converting the major, minor, and
         # revision numbers into two-decimal-digit numbers and appending them
         # together. For example, version 8.1.5 will be returned as 80105
-        self._version = cast(int, db_conn.server_version)
+        self._version = cast(int, db_conn.info.server_version)
         allow_unsafe_locale = self.config.get("allow_unsafe_locale", False)

         # Are we on a supported PostgreSQL version?
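
psycopg 3 moves the server version (and other connection metadata) onto conn.info, hence the db_conn.info.server_version change above. A hedged sketch (DSN illustrative):

    import psycopg

    with psycopg.connect("dbname=synapse") as conn:
        # psycopg2 exposed this as conn.server_version; psycopg 3 uses conn.info.
        # The integer packs major/minor, e.g. 140005 for PostgreSQL 14.5.
        print(conn.info.server_version)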
@@ -140,6 +140,9 @@ class PsycopgEngine(BaseDatabaseEngine[psycopg2.extensions.connection]):
         )

     def convert_param_style(self, sql: str) -> str:
+        if isinstance(sql, psycopg.sql.Composed):
+            return sql
+
         return sql.replace("?", "%s")

     def on_new_connection(self, db_conn: "LoggingDatabaseConnection") -> None:
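
The convert_param_style() change above keeps Synapse's "?" placeholders working: plain strings are rewritten to psycopg's "%s" style, while an already-composed psycopg.sql object is passed through untouched. A hedged standalone sketch:

    import psycopg.sql

    def convert_param_style(sql):
        if isinstance(sql, psycopg.sql.Composed):
            return sql  # already fully formed, nothing to rewrite
        return sql.replace("?", "%s")

    print(convert_param_style("SELECT * FROM users WHERE id = ?"))
    # SELECT * FROM users WHERE id = %s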
@@ -186,14 +189,14 @@ class PsycopgEngine(BaseDatabaseEngine[psycopg2.extensions.connection]):
         return True

     def is_deadlock(self, error: Exception) -> bool:
-        if isinstance(error, psycopg2.DatabaseError):
+        if isinstance(error, psycopg.errors.Error):
             # https://www.postgresql.org/docs/current/static/errcodes-appendix.html
             # "40001" serialization_failure
             # "40P01" deadlock_detected
-            return error.pgcode in ["40001", "40P01"]
+            return error.sqlstate in ["40001", "40P01"]
         return False

-    def is_connection_closed(self, conn: psycopg2.extensions.connection) -> bool:
+    def is_connection_closed(self, conn: psycopg.Connection) -> bool:
         return bool(conn.closed)

     def lock_table(self, txn: Cursor, table: str) -> None:
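
psycopg 3 reports the SQLSTATE of a server error as Error.sqlstate (psycopg2 used pgcode), which is what the is_deadlock() change above relies on. A hedged sketch of reading it (DSN illustrative):

    import psycopg
    import psycopg.errors

    try:
        with psycopg.connect("dbname=synapse") as conn:
            conn.execute("SELECT 1/0")
    except psycopg.errors.Error as e:
        # "22012" here (division_by_zero); deadlocks surface as "40P01",
        # serialization failures as "40001".
        print(e.sqlstate)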
@@ -213,19 +216,19 @@ class PsycopgEngine(BaseDatabaseEngine[psycopg2.extensions.connection]):
         else:
             return "%i.%i.%i" % (numver / 10000, (numver % 10000) / 100, numver % 100)

-    def in_transaction(self, conn: psycopg2.extensions.connection) -> bool:
-        return conn.status != psycopg2.extensions.STATUS_READY
+    def in_transaction(self, conn: psycopg.Connection) -> bool:
+        return conn.info.transaction_status != psycopg.pq.TransactionStatus.IDLE

     def attempt_to_set_autocommit(
-        self, conn: psycopg2.extensions.connection, autocommit: bool
+        self, conn: psycopg.Connection, autocommit: bool
     ) -> None:
-        return conn.set_session(autocommit=autocommit)
+        conn.autocommit = autocommit

     def attempt_to_set_isolation_level(
-        self, conn: psycopg2.extensions.connection, isolation_level: Optional[int]
+        self, conn: psycopg.Connection, isolation_level: Optional[int]
     ) -> None:
         if isolation_level is None:
-            isolation_level = self.default_isolation_level
+            pg_isolation_level = self.default_isolation_level
         else:
-            isolation_level = self.isolation_level_map[isolation_level]
-        return conn.set_isolation_level(isolation_level)
+            pg_isolation_level = self.isolation_level_map[isolation_level]
+        conn.isolation_level = pg_isolation_level
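
The last hunk swaps psycopg2's setter methods for psycopg 3's connection attributes. A hedged sketch of the equivalents used above (DSN illustrative; the attributes can only be changed while no transaction is open):

    import psycopg

    conn = psycopg.connect("dbname=synapse")
    conn.autocommit = True    # psycopg2: conn.set_session(autocommit=True)
    conn.autocommit = False
    conn.isolation_level = psycopg.IsolationLevel.SERIALIZABLE  # psycopg2: conn.set_isolation_level(...)
    in_txn = conn.info.transaction_status != psycopg.pq.TransactionStatus.IDLE
    print(in_txn)
    conn.close()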

View File

@@ -17,7 +17,7 @@
 Adds a postgres SEQUENCE for generating application service transaction IDs.
 """

-from synapse.storage.engines import PostgresEngine
+from synapse.storage.engines import PsycopgEngine


 def run_create(cur, database_engine, *args, **kwargs):
@@ -38,7 +38,14 @@ def run_create(cur, database_engine, *args, **kwargs):
     start_val = max(last_txn_max, txn_max) + 1

-    cur.execute(
-        "CREATE SEQUENCE application_services_txn_id_seq START WITH %s",
-        (start_val,),
-    )
+    # XXX This is a hack.
+    sql = f"CREATE SEQUENCE application_services_txn_id_seq START WITH {start_val}"
+    args = ()
+    if isinstance(database_engine, PsycopgEngine):
+        import psycopg.sql
+
+        cur.execute(
+            psycopg.sql.SQL(sql).format(args)
+        )
+    else:
+        cur.execute(sql, args)
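
The migration builds the CREATE SEQUENCE statement with an f-string (flagged as a hack in the diff itself). A hedged alternative sketch, not the migration's code, that composes the same DDL with psycopg.sql so the start value is quoted by the driver:

    import psycopg
    from psycopg import sql

    def create_as_txn_sequence(conn: psycopg.Connection, start_val: int) -> None:
        query = sql.SQL(
            "CREATE SEQUENCE application_services_txn_id_seq START WITH {}"
        ).format(sql.Literal(start_val))
        with conn.cursor() as cur:
            cur.execute(query)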

View File

@@ -33,7 +33,7 @@ class Cursor(Protocol):
     def fetchone(self) -> Optional[Tuple]:
         ...

-    def fetchmany(self, size: Optional[int] = ...) -> List[Tuple]:
+    def fetchmany(self, size: int = ...) -> List[Tuple]:
         ...

     def fetchall(self) -> List[Tuple]:
@@ -42,22 +42,7 @@ class Cursor(Protocol):
     @property
     def description(
         self,
-    ) -> Optional[
-        Sequence[
-            # Note that this is an approximate typing based on sqlite3 and other
-            # drivers, and may not be entirely accurate.
-            # FWIW, the DBAPI 2 spec is: https://peps.python.org/pep-0249/#description
-            Tuple[
-                str,
-                Optional[Any],
-                Optional[int],
-                Optional[int],
-                Optional[int],
-                Optional[int],
-                Optional[int],
-            ]
-        ]
-    ]:
+    ) -> Optional[Sequence[Any]]:
         ...

     @property
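
The Cursor protocol above loosens description to Optional[Sequence[Any]] because drivers disagree on the concrete shape; per DBAPI 2 (PEP 249) each column is a 7-item sequence. A small illustration with the stdlib sqlite3 driver:

    import sqlite3

    conn = sqlite3.connect(":memory:")
    cur = conn.execute("SELECT 1 AS one, 'a' AS letter")
    print(cur.description)
    # (('one', None, None, None, None, None, None),
    #  ('letter', None, None, None, None, None, None))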

View File

@@ -83,11 +83,11 @@ def setupdb() -> None:
     # Set up in the db
     db_conn = db_engine.module.connect(
-        database=POSTGRES_BASE_DB,
         user=POSTGRES_USER,
         host=POSTGRES_HOST,
         port=POSTGRES_PORT,
         password=POSTGRES_PASSWORD,
+        dbname=POSTGRES_BASE_DB,
     )
     logging_conn = LoggingDatabaseConnection(db_conn, db_engine, "tests")
     prepare_database(logging_conn, db_engine, None)
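
psycopg 3's connect() passes keyword arguments straight through as libpq connection parameters, which expect dbname (psycopg2 also accepted database as an alias), hence the swap above. A hedged sketch with illustrative values:

    import psycopg

    conn = psycopg.connect(
        dbname="synapse_tests",  # psycopg2 allowed database=..., psycopg 3 wants dbname
        user="postgres",
        host="localhost",
        port=5432,
        password="secret",
    )
    conn.close()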