From 4a0dfb336f93cedb402287636be80fe24af41b46 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Fri, 22 Sep 2023 15:26:21 -0400 Subject: [PATCH] temp --- synapse/storage/database.py | 30 ++++++-- .../databases/main/event_federation.py | 12 +++- synapse/storage/databases/main/relations.py | 4 +- .../storage/databases/main/user_directory.py | 71 +++++++++++-------- synapse/storage/engines/postgres.py | 7 +- synapse/storage/engines/psycopg.py | 6 ++ synapse/storage/util/id_generators.py | 2 +- tests/server.py | 6 +- 8 files changed, 93 insertions(+), 45 deletions(-) diff --git a/synapse/storage/database.py b/synapse/storage/database.py index 67c81a6cf2..dc5bf4d482 100644 --- a/synapse/storage/database.py +++ b/synapse/storage/database.py @@ -62,7 +62,7 @@ from synapse.storage.engines import ( BaseDatabaseEngine, PostgresEngine, PsycopgEngine, - Sqlite3Engine, + Sqlite3Engine, Psycopg2Engine, ) from synapse.storage.types import Connection, Cursor, SQLQueryParameters from synapse.util.async_helpers import delay_cancellation @@ -389,7 +389,7 @@ class LoggingTransaction: More efficient than `executemany` on PostgreSQL """ - if isinstance(self.database_engine, PostgresEngine): + if isinstance(self.database_engine, Psycopg2Engine): from psycopg2.extras import execute_batch # TODO: is it safe for values to be Iterable[Iterable[Any]] here? @@ -398,6 +398,8 @@ class LoggingTransaction: self._do_execute( lambda the_sql: execute_batch(self.txn, the_sql, args), sql ) + + # TODO Can psycopg3 do anything better? else: # TODO: is it safe for values to be Iterable[Iterable[Any]] here? # https://docs.python.org/3/library/sqlite3.html?highlight=sqlite3#sqlite3.Cursor.executemany @@ -422,7 +424,8 @@ class LoggingTransaction: The `template` is the snippet to merge to every item in argslist to compose the query. """ - assert isinstance(self.database_engine, PostgresEngine) + assert isinstance(self.database_engine, Psycopg2Engine) + from psycopg2.extras import execute_values return self._do_execute( @@ -435,6 +438,14 @@ class LoggingTransaction: values, ) + def copy( + self, + sql: str, args: Iterable[Iterable[Any]] + ) -> None: + with self.txn.copy(sql) as copy: + for record in args: + copy.write_row(record) + def execute(self, sql: str, parameters: SQLQueryParameters = ()) -> None: self._do_execute(self.txn.execute, sql, parameters) @@ -1180,7 +1191,7 @@ class DatabasePool: values: for each row, a list of values in the same order as `keys` """ - if isinstance(txn.database_engine, PostgresEngine): + if isinstance(txn.database_engine, Psycopg2Engine): # We use `execute_values` as it can be a lot faster than `execute_batch`, # but it's only available on postgres. sql = "INSERT INTO %s (%s) VALUES ?" % ( @@ -1189,6 +1200,13 @@ class DatabasePool: ) txn.execute_values(sql, values, fetch=False) + + elif isinstance(txn.database_engine, PsycopgEngine): + sql = "COPY %s (%s) FROM STDIN" % ( + table, ", ".join(k for k in keys), + ) + txn.copy(sql, values) + else: sql = "INSERT INTO %s (%s) VALUES(%s)" % ( table, @@ -1606,7 +1624,7 @@ class DatabasePool: for x, y in zip(key_values, value_values): args.append(tuple(x) + tuple(y)) - if isinstance(txn.database_engine, PostgresEngine): + if isinstance(txn.database_engine, Psycopg2Engine): # We use `execute_values` as it can be a lot faster than `execute_batch`, # but it's only available on postgres. sql = "INSERT INTO %s (%s) VALUES ? ON CONFLICT (%s) DO %s" % ( @@ -2362,7 +2380,7 @@ class DatabasePool: values: for each row, a list of values in the same order as `keys` """ - if isinstance(txn.database_engine, PostgresEngine): + if isinstance(txn.database_engine, Psycopg2Engine): # We use `execute_values` as it can be a lot faster than `execute_batch`, # but it's only available on postgres. sql = "DELETE FROM %s WHERE (%s) IN (VALUES ?)" % ( diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py index afffa54985..c908352b50 100644 --- a/synapse/storage/databases/main/event_federation.py +++ b/synapse/storage/databases/main/event_federation.py @@ -47,7 +47,7 @@ from synapse.storage.database import ( ) from synapse.storage.databases.main.events_worker import EventsWorkerStore from synapse.storage.databases.main.signatures import SignatureWorkerStore -from synapse.storage.engines import PostgresEngine, Sqlite3Engine +from synapse.storage.engines import PostgresEngine, Sqlite3Engine, Psycopg2Engine from synapse.types import JsonDict, StrCollection from synapse.util import json_encoder from synapse.util.caches.descriptors import cached @@ -321,7 +321,10 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas AND sequence_number <= max_seq """ - rows = txn.execute_values(sql, chains.items()) + if isinstance(self.database_engine, Psycopg2Engine): + rows = txn.execute_values(sql, chains.items()) + else: + rows = txn.executemany(sql, chains.items()) results.update(r for r, in rows) else: # For SQLite we just fall back to doing a noddy for loop. @@ -597,7 +600,10 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas for chain_id, (min_no, max_no) in chain_to_gap.items() ] - rows = txn.execute_values(sql, args) + if isinstance(self.database_engine, Psycopg2Engine): + rows = txn.execute_values(sql, args) + else: + rows = txn.executemany(sql, args) result.update(r for r, in rows) else: # For SQLite we just fall back to doing a noddy for loop. diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py index b67f780c10..c04d45bdb5 100644 --- a/synapse/storage/databases/main/relations.py +++ b/synapse/storage/databases/main/relations.py @@ -46,7 +46,7 @@ from synapse.storage.databases.main.stream import ( generate_pagination_bounds, generate_pagination_where_clause, ) -from synapse.storage.engines import PostgresEngine +from synapse.storage.engines import PostgresEngine, Psycopg2Engine from synapse.types import JsonDict, StreamKeyType, StreamToken from synapse.util.caches.descriptors import cached, cachedList @@ -139,7 +139,7 @@ class RelationsWorkerStore(SQLBaseStore): ON CONFLICT (room_id, thread_id) DO NOTHING """ - if isinstance(txn.database_engine, PostgresEngine): + if isinstance(txn.database_engine, Psycopg2Engine): txn.execute_values(sql % ("?",), rows, fetch=False) else: txn.execute_batch(sql % ("(?, ?, ?, ?, ?)",), rows) diff --git a/synapse/storage/databases/main/user_directory.py b/synapse/storage/databases/main/user_directory.py index f0dc31fee6..f589bc781c 100644 --- a/synapse/storage/databases/main/user_directory.py +++ b/synapse/storage/databases/main/user_directory.py @@ -54,7 +54,8 @@ from synapse.storage.database import ( ) from synapse.storage.databases.main.state import StateFilter from synapse.storage.databases.main.state_deltas import StateDeltasStore -from synapse.storage.engines import PostgresEngine, Sqlite3Engine +from synapse.storage.engines import PostgresEngine, Sqlite3Engine, PsycopgEngine, \ + Psycopg2Engine from synapse.types import ( JsonDict, UserID, @@ -717,35 +718,49 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore): if isinstance(self.database_engine, PostgresEngine): # We weight the localpart most highly, then display name and finally # server name - template = """ - ( - %s, - setweight(to_tsvector('simple', %s), 'A') - || setweight(to_tsvector('simple', %s), 'D') - || setweight(to_tsvector('simple', COALESCE(%s, '')), 'B') - ) - """ sql = """ - INSERT INTO user_directory_search(user_id, vector) - VALUES ? ON CONFLICT (user_id) DO UPDATE SET vector=EXCLUDED.vector - """ - txn.execute_values( - sql, - [ - ( - p.user_id, - get_localpart_from_id(p.user_id), - get_domain_from_id(p.user_id), - _filter_text_for_index(p.display_name) - if p.display_name - else None, - ) - for p in profiles - ], - template=template, - fetch=False, - ) + INSERT INTO user_directory_search(user_id, vector) + VALUES + ( + ?, + setweight(to_tsvector('simple', ?), 'A') + || setweight(to_tsvector('simple', ?), 'D') + || setweight(to_tsvector('simple', COALESCE(?, '')), 'B') + ) + ON CONFLICT (user_id) DO UPDATE SET vector=EXCLUDED.vector + """ + if isinstance(self.database_engine, Psycopg2Engine): + txn.execute_values( + sql, + [ + ( + p.user_id, + get_localpart_from_id(p.user_id), + get_domain_from_id(p.user_id), + _filter_text_for_index(p.display_name) + if p.display_name + else None, + ) + for p in profiles + ], + fetch=False, + ) + if isinstance(self.database_engine, PsycopgEngine): + txn.executemany( + sql, + [ + ( + p.user_id, + get_localpart_from_id(p.user_id), + get_domain_from_id(p.user_id), + _filter_text_for_index(p.display_name) + if p.display_name + else None, + ) + for p in profiles + ], + ) elif isinstance(self.database_engine, Sqlite3Engine): values = [] for p in profiles: diff --git a/synapse/storage/engines/postgres.py b/synapse/storage/engines/postgres.py index ec143181fa..c28730c4ca 100644 --- a/synapse/storage/engines/postgres.py +++ b/synapse/storage/engines/postgres.py @@ -164,8 +164,11 @@ class PostgresEngine(BaseDatabaseEngine[ConnectionType, CursorType], metaclass=a # Abort really long-running statements and turn them into errors. if self.statement_timeout is not None: - cursor.execute(sql.SQL("SET statement_timeout TO {}").format(self.statement_timeout)) - #cursor.execute("SELECT set_config( 'statement_timeout', ?, false)", (self.statement_timeout,)) + # TODO Avoid a circular import, this needs to be abstracted. + if self.__class__.__name__ == "Psycopg2Engine": + cursor.execute("SET statement_timeout TO ?", (self.statement_timeout,)) + else: + cursor.execute(sql.SQL("SET statement_timeout TO {}").format(self.statement_timeout)) cursor.close() db_conn.commit() diff --git a/synapse/storage/engines/psycopg.py b/synapse/storage/engines/psycopg.py index e9cc02e795..ddd745217f 100644 --- a/synapse/storage/engines/psycopg.py +++ b/synapse/storage/engines/psycopg.py @@ -19,6 +19,8 @@ import psycopg import psycopg.errors import psycopg.sql +from twisted.enterprise.adbapi import Connection as TxConnection + from synapse.storage.engines import PostgresEngine from synapse.storage.engines._base import ( IsolationLevel, @@ -73,6 +75,10 @@ class PsycopgEngine(PostgresEngine[psycopg.Connection, psycopg.Cursor]): def attempt_to_set_autocommit( self, conn: psycopg.Connection, autocommit: bool ) -> None: + # Sometimes this gets called with a Twisted connection instead, unwrap + # it because it doesn't support __setattr__. + if isinstance(conn, TxConnection): + conn = conn._connection conn.autocommit = autocommit def attempt_to_set_isolation_level( diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py index d2c874b9a8..aa4fa40c9c 100644 --- a/synapse/storage/util/id_generators.py +++ b/synapse/storage/util/id_generators.py @@ -811,7 +811,7 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator): "agg": "GREATEST" if self._positive else "LEAST", } - pos = (self.get_current_token_for_writer(self._instance_name),) + pos = self.get_current_token_for_writer(self._instance_name) txn.execute(sql, (self._stream_name, self._instance_name, pos)) diff --git a/tests/server.py b/tests/server.py index c47b536efc..7f5b5ba8f8 100644 --- a/tests/server.py +++ b/tests/server.py @@ -960,11 +960,11 @@ def setup_test_homeserver( test_db = "synapse_test_%s" % uuid.uuid4().hex if USE_POSTGRES_FOR_TESTS == "psycopg": - name = "psycopg" + db_type = "psycopg" else: - name = "psycopg2" + db_type = "psycopg2" database_config = { - "name": name, + "name": db_type, "args": { "dbname": test_db, "host": POSTGRES_HOST,