Working version.

clokep/psycopg3
Patrick Cloke 2022-07-29 15:15:18 -04:00
parent f5ef7e13d7
commit 9f0ccbdbaf
6 changed files with 58 additions and 57 deletions

View File

@ -56,7 +56,7 @@ from synapse.logging.context import (
from synapse.metrics import register_threadpool
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage.background_updates import BackgroundUpdater
from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine, Sqlite3Engine
from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine, PsycopgEngine, Sqlite3Engine
from synapse.storage.types import Connection, Cursor
from synapse.util.async_helpers import delay_cancellation
from synapse.util.iterutils import batch_iter
@ -334,7 +334,8 @@ class LoggingTransaction:
def fetchone(self) -> Optional[Tuple]:
return self.txn.fetchone()
def fetchmany(self, size: Optional[int] = None) -> List[Tuple]:
def fetchmany(self, size: int = 0) -> List[Tuple]:
# XXX This can also be called with no arguments.
return self.txn.fetchmany(size=size)
def fetchall(self) -> List[Tuple]:
@ -400,6 +401,11 @@ class LoggingTransaction:
def _make_sql_one_line(self, sql: str) -> str:
"Strip newlines out of SQL so that the loggers in the DB are on one line"
if isinstance(self.database_engine, PsycopgEngine):
import psycopg.sql
if isinstance(sql, psycopg.sql.Composed):
return sql.as_string(None)
return " ".join(line.strip() for line in sql.splitlines() if line.strip())
def _do_execute(
@ -440,7 +446,7 @@ class LoggingTransaction:
finally:
secs = time.time() - start
sql_logger.debug("[SQL time] {%s} %f sec", self.name, secs)
sql_query_timer.labels(sql.split()[0]).observe(secs)
sql_query_timer.labels(one_line_sql.split()[0]).observe(secs)
def close(self) -> None:
self.txn.close()

View File

@ -21,7 +21,7 @@ from ._base import BaseDatabaseEngine, IncorrectDatabaseSetup
# installed. To account for this, create dummy classes on import failure so we can
# still run `isinstance()` checks.
def dummy_engine(name: str, module: str) -> BaseDatabaseEngine:
class Engine(BaseDatabaseEngine): # type: ignore[no-redef]
class Engine(BaseDatabaseEngine):
def __new__(cls, *args: object, **kwargs: object) -> NoReturn: # type: ignore[misc]
raise RuntimeError(
f"Cannot create {name}Engine -- {module} module is not installed"
@ -33,17 +33,17 @@ def dummy_engine(name: str, module: str) -> BaseDatabaseEngine:
try:
from .postgres import PostgresEngine
except ImportError:
PostgresEngine = dummy_engine("PostgresEngine", "psycopg2")
PostgresEngine = dummy_engine("PostgresEngine", "psycopg2") # type: ignore[misc,assignment]
try:
from .psycopg import PsycopgEngine
except ImportError:
PsycopgEngine = dummy_engine("PsycopgEngine", "psycopg")
PsycopgEngine = dummy_engine("PsycopgEngine", "psycopg") # type: ignore[misc,assignment]
try:
from .sqlite import Sqlite3Engine
except ImportError:
Sqlite3Engine = dummy_engine("Sqlite3Engine", "sqlite3")
Sqlite3Engine = dummy_engine("Sqlite3Engine", "sqlite3") # type: ignore[misc,assignment]
def create_engine(database_config: Mapping[str, Any]) -> BaseDatabaseEngine:

View File

@ -15,7 +15,9 @@
import logging
from typing import TYPE_CHECKING, Any, Mapping, NoReturn, Optional, Tuple, cast
import psycopg2.extensions
import psycopg
import psycopg.errors
import psycopg.sql
from synapse.storage.engines._base import (
BaseDatabaseEngine,
@ -31,28 +33,26 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
class PsycopgEngine(BaseDatabaseEngine[psycopg2.extensions.connection]):
class PsycopgEngine(BaseDatabaseEngine[psycopg.Connection]):
def __init__(self, database_config: Mapping[str, Any]):
super().__init__(psycopg2, database_config)
psycopg2.extensions.register_type(psycopg2.extensions.UNICODE)
super().__init__(psycopg, database_config)
# psycopg2.extensions.register_type(psycopg2.extensions.UNICODE)
# Disables passing `bytes` to txn.execute, c.f. #6186. If you do
# actually want to use bytes than wrap it in `bytearray`.
def _disable_bytes_adapter(_: bytes) -> NoReturn:
raise Exception("Passing bytes to DB is disabled.")
# def _disable_bytes_adapter(_: bytes) -> NoReturn:
# raise Exception("Passing bytes to DB is disabled.")
psycopg2.extensions.register_adapter(bytes, _disable_bytes_adapter)
# psycopg2.extensions.register_adapter(bytes, _disable_bytes_adapter)
self.synchronous_commit: bool = database_config.get("synchronous_commit", True)
self._version: Optional[int] = None # unknown as yet
self.isolation_level_map: Mapping[int, int] = {
IsolationLevel.READ_COMMITTED: psycopg2.extensions.ISOLATION_LEVEL_READ_COMMITTED,
IsolationLevel.REPEATABLE_READ: psycopg2.extensions.ISOLATION_LEVEL_REPEATABLE_READ,
IsolationLevel.SERIALIZABLE: psycopg2.extensions.ISOLATION_LEVEL_SERIALIZABLE,
self.isolation_level_map: Mapping[int, psycopg.IsolationLevel] = {
IsolationLevel.READ_COMMITTED: psycopg.IsolationLevel.READ_COMMITTED,
IsolationLevel.REPEATABLE_READ: psycopg.IsolationLevel.REPEATABLE_READ,
IsolationLevel.SERIALIZABLE: psycopg.IsolationLevel.SERIALIZABLE,
}
self.default_isolation_level = (
psycopg2.extensions.ISOLATION_LEVEL_REPEATABLE_READ
)
self.default_isolation_level = psycopg.IsolationLevel.REPEATABLE_READ
self.config = database_config
@property
@ -68,14 +68,14 @@ class PsycopgEngine(BaseDatabaseEngine[psycopg2.extensions.connection]):
def check_database(
self,
db_conn: psycopg2.extensions.connection,
db_conn: psycopg.Connection,
allow_outdated_version: bool = False,
) -> None:
# Get the version of PostgreSQL that we're using. As per the psycopg2
# docs: The number is formed by converting the major, minor, and
# revision numbers into two-decimal-digit numbers and appending them
# together. For example, version 8.1.5 will be returned as 80105
self._version = cast(int, db_conn.server_version)
self._version = cast(int, db_conn.info.server_version)
allow_unsafe_locale = self.config.get("allow_unsafe_locale", False)
# Are we on a supported PostgreSQL version?
@ -140,6 +140,9 @@ class PsycopgEngine(BaseDatabaseEngine[psycopg2.extensions.connection]):
)
def convert_param_style(self, sql: str) -> str:
if isinstance(sql, psycopg.sql.Composed):
return sql
return sql.replace("?", "%s")
def on_new_connection(self, db_conn: "LoggingDatabaseConnection") -> None:
@ -186,14 +189,14 @@ class PsycopgEngine(BaseDatabaseEngine[psycopg2.extensions.connection]):
return True
def is_deadlock(self, error: Exception) -> bool:
if isinstance(error, psycopg2.DatabaseError):
if isinstance(error, psycopg.errors.Error):
# https://www.postgresql.org/docs/current/static/errcodes-appendix.html
# "40001" serialization_failure
# "40P01" deadlock_detected
return error.pgcode in ["40001", "40P01"]
return error.sqlstate in ["40001", "40P01"]
return False
def is_connection_closed(self, conn: psycopg2.extensions.connection) -> bool:
def is_connection_closed(self, conn: psycopg.Connection) -> bool:
return bool(conn.closed)
def lock_table(self, txn: Cursor, table: str) -> None:
@ -213,19 +216,19 @@ class PsycopgEngine(BaseDatabaseEngine[psycopg2.extensions.connection]):
else:
return "%i.%i.%i" % (numver / 10000, (numver % 10000) / 100, numver % 100)
def in_transaction(self, conn: psycopg2.extensions.connection) -> bool:
return conn.status != psycopg2.extensions.STATUS_READY
def in_transaction(self, conn: psycopg.Connection) -> bool:
return conn.info.transaction_status != psycopg.pq.TransactionStatus.IDLE
def attempt_to_set_autocommit(
self, conn: psycopg2.extensions.connection, autocommit: bool
self, conn: psycopg.Connection, autocommit: bool
) -> None:
return conn.set_session(autocommit=autocommit)
conn.autocommit = autocommit
def attempt_to_set_isolation_level(
self, conn: psycopg2.extensions.connection, isolation_level: Optional[int]
self, conn: psycopg.Connection, isolation_level: Optional[int]
) -> None:
if isolation_level is None:
isolation_level = self.default_isolation_level
pg_isolation_level = self.default_isolation_level
else:
isolation_level = self.isolation_level_map[isolation_level]
return conn.set_isolation_level(isolation_level)
pg_isolation_level = self.isolation_level_map[isolation_level]
conn.isolation_level = pg_isolation_level

View File

@ -17,7 +17,7 @@
Adds a postgres SEQUENCE for generating application service transaction IDs.
"""
from synapse.storage.engines import PostgresEngine
from synapse.storage.engines import PsycopgEngine
def run_create(cur, database_engine, *args, **kwargs):
@ -38,7 +38,14 @@ def run_create(cur, database_engine, *args, **kwargs):
start_val = max(last_txn_max, txn_max) + 1
cur.execute(
"CREATE SEQUENCE application_services_txn_id_seq START WITH %s",
(start_val,),
)
# XXX This is a hack.
sql = f"CREATE SEQUENCE application_services_txn_id_seq START WITH {start_val}"
args = ()
if isinstance(database_engine, PsycopgEngine):
import psycopg.sql
cur.execute(
psycopg.sql.SQL(sql).format(args)
)
else:
cur.execute(sql, args)

View File

@ -33,7 +33,7 @@ class Cursor(Protocol):
def fetchone(self) -> Optional[Tuple]:
...
def fetchmany(self, size: Optional[int] = ...) -> List[Tuple]:
def fetchmany(self, size: int = ...) -> List[Tuple]:
...
def fetchall(self) -> List[Tuple]:
@ -42,22 +42,7 @@ class Cursor(Protocol):
@property
def description(
self,
) -> Optional[
Sequence[
# Note that this is an approximate typing based on sqlite3 and other
# drivers, and may not be entirely accurate.
# FWIW, the DBAPI 2 spec is: https://peps.python.org/pep-0249/#description
Tuple[
str,
Optional[Any],
Optional[int],
Optional[int],
Optional[int],
Optional[int],
Optional[int],
]
]
]:
) -> Optional[Sequence[Any]]:
...
@property

View File

@ -83,11 +83,11 @@ def setupdb() -> None:
# Set up in the db
db_conn = db_engine.module.connect(
database=POSTGRES_BASE_DB,
user=POSTGRES_USER,
host=POSTGRES_HOST,
port=POSTGRES_PORT,
password=POSTGRES_PASSWORD,
dbname=POSTGRES_BASE_DB,
)
logging_conn = LoggingDatabaseConnection(db_conn, db_engine, "tests")
prepare_database(logging_conn, db_engine, None)