Working version.
parent
f5ef7e13d7
commit
9f0ccbdbaf
|
@ -56,7 +56,7 @@ from synapse.logging.context import (
|
|||
from synapse.metrics import register_threadpool
|
||||
from synapse.metrics.background_process_metrics import run_as_background_process
|
||||
from synapse.storage.background_updates import BackgroundUpdater
|
||||
from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine, Sqlite3Engine
|
||||
from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine, PsycopgEngine, Sqlite3Engine
|
||||
from synapse.storage.types import Connection, Cursor
|
||||
from synapse.util.async_helpers import delay_cancellation
|
||||
from synapse.util.iterutils import batch_iter
|
||||
|
@ -334,7 +334,8 @@ class LoggingTransaction:
|
|||
def fetchone(self) -> Optional[Tuple]:
|
||||
return self.txn.fetchone()
|
||||
|
||||
def fetchmany(self, size: Optional[int] = None) -> List[Tuple]:
|
||||
def fetchmany(self, size: int = 0) -> List[Tuple]:
|
||||
# XXX This can also be called with no arguments.
|
||||
return self.txn.fetchmany(size=size)
|
||||
|
||||
def fetchall(self) -> List[Tuple]:
|
||||
|
@ -400,6 +401,11 @@ class LoggingTransaction:
|
|||
|
||||
def _make_sql_one_line(self, sql: str) -> str:
|
||||
"Strip newlines out of SQL so that the loggers in the DB are on one line"
|
||||
if isinstance(self.database_engine, PsycopgEngine):
|
||||
import psycopg.sql
|
||||
if isinstance(sql, psycopg.sql.Composed):
|
||||
return sql.as_string(None)
|
||||
|
||||
return " ".join(line.strip() for line in sql.splitlines() if line.strip())
|
||||
|
||||
def _do_execute(
|
||||
|
@ -440,7 +446,7 @@ class LoggingTransaction:
|
|||
finally:
|
||||
secs = time.time() - start
|
||||
sql_logger.debug("[SQL time] {%s} %f sec", self.name, secs)
|
||||
sql_query_timer.labels(sql.split()[0]).observe(secs)
|
||||
sql_query_timer.labels(one_line_sql.split()[0]).observe(secs)
|
||||
|
||||
def close(self) -> None:
|
||||
self.txn.close()
|
||||
|
|
|
@ -21,7 +21,7 @@ from ._base import BaseDatabaseEngine, IncorrectDatabaseSetup
|
|||
# installed. To account for this, create dummy classes on import failure so we can
|
||||
# still run `isinstance()` checks.
|
||||
def dummy_engine(name: str, module: str) -> BaseDatabaseEngine:
|
||||
class Engine(BaseDatabaseEngine): # type: ignore[no-redef]
|
||||
class Engine(BaseDatabaseEngine):
|
||||
def __new__(cls, *args: object, **kwargs: object) -> NoReturn: # type: ignore[misc]
|
||||
raise RuntimeError(
|
||||
f"Cannot create {name}Engine -- {module} module is not installed"
|
||||
|
@ -33,17 +33,17 @@ def dummy_engine(name: str, module: str) -> BaseDatabaseEngine:
|
|||
try:
|
||||
from .postgres import PostgresEngine
|
||||
except ImportError:
|
||||
PostgresEngine = dummy_engine("PostgresEngine", "psycopg2")
|
||||
PostgresEngine = dummy_engine("PostgresEngine", "psycopg2") # type: ignore[misc,assignment]
|
||||
|
||||
try:
|
||||
from .psycopg import PsycopgEngine
|
||||
except ImportError:
|
||||
PsycopgEngine = dummy_engine("PsycopgEngine", "psycopg")
|
||||
PsycopgEngine = dummy_engine("PsycopgEngine", "psycopg") # type: ignore[misc,assignment]
|
||||
|
||||
try:
|
||||
from .sqlite import Sqlite3Engine
|
||||
except ImportError:
|
||||
Sqlite3Engine = dummy_engine("Sqlite3Engine", "sqlite3")
|
||||
Sqlite3Engine = dummy_engine("Sqlite3Engine", "sqlite3") # type: ignore[misc,assignment]
|
||||
|
||||
|
||||
def create_engine(database_config: Mapping[str, Any]) -> BaseDatabaseEngine:
|
||||
|
|
|
@ -15,7 +15,9 @@
|
|||
import logging
|
||||
from typing import TYPE_CHECKING, Any, Mapping, NoReturn, Optional, Tuple, cast
|
||||
|
||||
import psycopg2.extensions
|
||||
import psycopg
|
||||
import psycopg.errors
|
||||
import psycopg.sql
|
||||
|
||||
from synapse.storage.engines._base import (
|
||||
BaseDatabaseEngine,
|
||||
|
@ -31,28 +33,26 @@ if TYPE_CHECKING:
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class PsycopgEngine(BaseDatabaseEngine[psycopg2.extensions.connection]):
|
||||
class PsycopgEngine(BaseDatabaseEngine[psycopg.Connection]):
|
||||
def __init__(self, database_config: Mapping[str, Any]):
|
||||
super().__init__(psycopg2, database_config)
|
||||
psycopg2.extensions.register_type(psycopg2.extensions.UNICODE)
|
||||
super().__init__(psycopg, database_config)
|
||||
# psycopg2.extensions.register_type(psycopg2.extensions.UNICODE)
|
||||
|
||||
# Disables passing `bytes` to txn.execute, c.f. #6186. If you do
|
||||
# actually want to use bytes than wrap it in `bytearray`.
|
||||
def _disable_bytes_adapter(_: bytes) -> NoReturn:
|
||||
raise Exception("Passing bytes to DB is disabled.")
|
||||
# def _disable_bytes_adapter(_: bytes) -> NoReturn:
|
||||
# raise Exception("Passing bytes to DB is disabled.")
|
||||
|
||||
psycopg2.extensions.register_adapter(bytes, _disable_bytes_adapter)
|
||||
# psycopg2.extensions.register_adapter(bytes, _disable_bytes_adapter)
|
||||
self.synchronous_commit: bool = database_config.get("synchronous_commit", True)
|
||||
self._version: Optional[int] = None # unknown as yet
|
||||
|
||||
self.isolation_level_map: Mapping[int, int] = {
|
||||
IsolationLevel.READ_COMMITTED: psycopg2.extensions.ISOLATION_LEVEL_READ_COMMITTED,
|
||||
IsolationLevel.REPEATABLE_READ: psycopg2.extensions.ISOLATION_LEVEL_REPEATABLE_READ,
|
||||
IsolationLevel.SERIALIZABLE: psycopg2.extensions.ISOLATION_LEVEL_SERIALIZABLE,
|
||||
self.isolation_level_map: Mapping[int, psycopg.IsolationLevel] = {
|
||||
IsolationLevel.READ_COMMITTED: psycopg.IsolationLevel.READ_COMMITTED,
|
||||
IsolationLevel.REPEATABLE_READ: psycopg.IsolationLevel.REPEATABLE_READ,
|
||||
IsolationLevel.SERIALIZABLE: psycopg.IsolationLevel.SERIALIZABLE,
|
||||
}
|
||||
self.default_isolation_level = (
|
||||
psycopg2.extensions.ISOLATION_LEVEL_REPEATABLE_READ
|
||||
)
|
||||
self.default_isolation_level = psycopg.IsolationLevel.REPEATABLE_READ
|
||||
self.config = database_config
|
||||
|
||||
@property
|
||||
|
@ -68,14 +68,14 @@ class PsycopgEngine(BaseDatabaseEngine[psycopg2.extensions.connection]):
|
|||
|
||||
def check_database(
|
||||
self,
|
||||
db_conn: psycopg2.extensions.connection,
|
||||
db_conn: psycopg.Connection,
|
||||
allow_outdated_version: bool = False,
|
||||
) -> None:
|
||||
# Get the version of PostgreSQL that we're using. As per the psycopg2
|
||||
# docs: The number is formed by converting the major, minor, and
|
||||
# revision numbers into two-decimal-digit numbers and appending them
|
||||
# together. For example, version 8.1.5 will be returned as 80105
|
||||
self._version = cast(int, db_conn.server_version)
|
||||
self._version = cast(int, db_conn.info.server_version)
|
||||
allow_unsafe_locale = self.config.get("allow_unsafe_locale", False)
|
||||
|
||||
# Are we on a supported PostgreSQL version?
|
||||
|
@ -140,6 +140,9 @@ class PsycopgEngine(BaseDatabaseEngine[psycopg2.extensions.connection]):
|
|||
)
|
||||
|
||||
def convert_param_style(self, sql: str) -> str:
|
||||
if isinstance(sql, psycopg.sql.Composed):
|
||||
return sql
|
||||
|
||||
return sql.replace("?", "%s")
|
||||
|
||||
def on_new_connection(self, db_conn: "LoggingDatabaseConnection") -> None:
|
||||
|
@ -186,14 +189,14 @@ class PsycopgEngine(BaseDatabaseEngine[psycopg2.extensions.connection]):
|
|||
return True
|
||||
|
||||
def is_deadlock(self, error: Exception) -> bool:
|
||||
if isinstance(error, psycopg2.DatabaseError):
|
||||
if isinstance(error, psycopg.errors.Error):
|
||||
# https://www.postgresql.org/docs/current/static/errcodes-appendix.html
|
||||
# "40001" serialization_failure
|
||||
# "40P01" deadlock_detected
|
||||
return error.pgcode in ["40001", "40P01"]
|
||||
return error.sqlstate in ["40001", "40P01"]
|
||||
return False
|
||||
|
||||
def is_connection_closed(self, conn: psycopg2.extensions.connection) -> bool:
|
||||
def is_connection_closed(self, conn: psycopg.Connection) -> bool:
|
||||
return bool(conn.closed)
|
||||
|
||||
def lock_table(self, txn: Cursor, table: str) -> None:
|
||||
|
@ -213,19 +216,19 @@ class PsycopgEngine(BaseDatabaseEngine[psycopg2.extensions.connection]):
|
|||
else:
|
||||
return "%i.%i.%i" % (numver / 10000, (numver % 10000) / 100, numver % 100)
|
||||
|
||||
def in_transaction(self, conn: psycopg2.extensions.connection) -> bool:
|
||||
return conn.status != psycopg2.extensions.STATUS_READY
|
||||
def in_transaction(self, conn: psycopg.Connection) -> bool:
|
||||
return conn.info.transaction_status != psycopg.pq.TransactionStatus.IDLE
|
||||
|
||||
def attempt_to_set_autocommit(
|
||||
self, conn: psycopg2.extensions.connection, autocommit: bool
|
||||
self, conn: psycopg.Connection, autocommit: bool
|
||||
) -> None:
|
||||
return conn.set_session(autocommit=autocommit)
|
||||
conn.autocommit = autocommit
|
||||
|
||||
def attempt_to_set_isolation_level(
|
||||
self, conn: psycopg2.extensions.connection, isolation_level: Optional[int]
|
||||
self, conn: psycopg.Connection, isolation_level: Optional[int]
|
||||
) -> None:
|
||||
if isolation_level is None:
|
||||
isolation_level = self.default_isolation_level
|
||||
pg_isolation_level = self.default_isolation_level
|
||||
else:
|
||||
isolation_level = self.isolation_level_map[isolation_level]
|
||||
return conn.set_isolation_level(isolation_level)
|
||||
pg_isolation_level = self.isolation_level_map[isolation_level]
|
||||
conn.isolation_level = pg_isolation_level
|
||||
|
|
|
@ -17,7 +17,7 @@
|
|||
Adds a postgres SEQUENCE for generating application service transaction IDs.
|
||||
"""
|
||||
|
||||
from synapse.storage.engines import PostgresEngine
|
||||
from synapse.storage.engines import PsycopgEngine
|
||||
|
||||
|
||||
def run_create(cur, database_engine, *args, **kwargs):
|
||||
|
@ -38,7 +38,14 @@ def run_create(cur, database_engine, *args, **kwargs):
|
|||
|
||||
start_val = max(last_txn_max, txn_max) + 1
|
||||
|
||||
cur.execute(
|
||||
"CREATE SEQUENCE application_services_txn_id_seq START WITH %s",
|
||||
(start_val,),
|
||||
)
|
||||
# XXX This is a hack.
|
||||
sql = f"CREATE SEQUENCE application_services_txn_id_seq START WITH {start_val}"
|
||||
args = ()
|
||||
|
||||
if isinstance(database_engine, PsycopgEngine):
|
||||
import psycopg.sql
|
||||
cur.execute(
|
||||
psycopg.sql.SQL(sql).format(args)
|
||||
)
|
||||
else:
|
||||
cur.execute(sql, args)
|
||||
|
|
|
@ -33,7 +33,7 @@ class Cursor(Protocol):
|
|||
def fetchone(self) -> Optional[Tuple]:
|
||||
...
|
||||
|
||||
def fetchmany(self, size: Optional[int] = ...) -> List[Tuple]:
|
||||
def fetchmany(self, size: int = ...) -> List[Tuple]:
|
||||
...
|
||||
|
||||
def fetchall(self) -> List[Tuple]:
|
||||
|
@ -42,22 +42,7 @@ class Cursor(Protocol):
|
|||
@property
|
||||
def description(
|
||||
self,
|
||||
) -> Optional[
|
||||
Sequence[
|
||||
# Note that this is an approximate typing based on sqlite3 and other
|
||||
# drivers, and may not be entirely accurate.
|
||||
# FWIW, the DBAPI 2 spec is: https://peps.python.org/pep-0249/#description
|
||||
Tuple[
|
||||
str,
|
||||
Optional[Any],
|
||||
Optional[int],
|
||||
Optional[int],
|
||||
Optional[int],
|
||||
Optional[int],
|
||||
Optional[int],
|
||||
]
|
||||
]
|
||||
]:
|
||||
) -> Optional[Sequence[Any]]:
|
||||
...
|
||||
|
||||
@property
|
||||
|
|
|
@ -83,11 +83,11 @@ def setupdb() -> None:
|
|||
|
||||
# Set up in the db
|
||||
db_conn = db_engine.module.connect(
|
||||
database=POSTGRES_BASE_DB,
|
||||
user=POSTGRES_USER,
|
||||
host=POSTGRES_HOST,
|
||||
port=POSTGRES_PORT,
|
||||
password=POSTGRES_PASSWORD,
|
||||
dbname=POSTGRES_BASE_DB,
|
||||
)
|
||||
logging_conn = LoggingDatabaseConnection(db_conn, db_engine, "tests")
|
||||
prepare_database(logging_conn, db_engine, None)
|
||||
|
|
Loading…
Reference in New Issue