diff --git a/synapse/storage/database.py b/synapse/storage/database.py index 3295ebc74d..6aeaedeaef 100644 --- a/synapse/storage/database.py +++ b/synapse/storage/database.py @@ -426,14 +426,24 @@ class LoggingTransaction: values, ) - def copy( + def copy_write( self, sql: str, args: Iterable[Iterable[Any]] ) -> None: + # TODO use _do_execute with self.txn.copy(sql) as copy: for record in args: copy.write_row(record) + def copy_read( + self, + sql: str, args: Iterable[Iterable[Any]] + ) -> Iterable[Tuple[Any, ...]]: + # TODO use _do_execute + sql = self.database_engine.convert_param_style(sql) + with self.txn.copy(sql, args) as copy: + yield from copy.rows() + def execute(self, sql: str, parameters: SQLQueryParameters = ()) -> None: self._do_execute(self.txn.execute, sql, parameters) @@ -1193,7 +1203,7 @@ class DatabasePool: sql = "COPY %s (%s) FROM STDIN" % ( table, ", ".join(k for k in keys), ) - txn.copy(sql, values) + txn.copy_write(sql, values) else: sql = "INSERT INTO %s (%s) VALUES(%s)" % ( diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py index c908352b50..e2e0717412 100644 --- a/synapse/storage/databases/main/event_federation.py +++ b/synapse/storage/databases/main/event_federation.py @@ -311,20 +311,30 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas results = set() if isinstance(self.database_engine, PostgresEngine): - # We can use `execute_values` to efficiently fetch the gaps when - # using postgres. - sql = """ - SELECT event_id - FROM event_auth_chains AS c, (VALUES ?) AS l(chain_id, max_seq) - WHERE - c.chain_id = l.chain_id - AND sequence_number <= max_seq - """ - if isinstance(self.database_engine, Psycopg2Engine): + # We can use `execute_values` to efficiently fetch the gaps when + # using postgres. + sql = """ + SELECT event_id + FROM event_auth_chains AS c, (VALUES ?) AS l(chain_id, max_seq) + WHERE + c.chain_id = l.chain_id + AND sequence_number <= max_seq + """ rows = txn.execute_values(sql, chains.items()) else: - rows = txn.executemany(sql, chains.items()) + sql = """ + COPY ( + SELECT event_id + FROM event_auth_chains AS c, (VALUES %s) AS l(chain_id, max_seq) + WHERE + c.chain_id = l.chain_id + AND sequence_number <= max_seq + ) + TO STDOUT + """ % (", ".join("(?, ?)" for _ in chains)) + # Flatten the arguments. + rows = txn.copy_read(sql, list(itertools.chain.from_iterable(chains.items()))) results.update(r for r, in rows) else: # For SQLite we just fall back to doing a noddy for loop. @@ -585,25 +595,36 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas return result if isinstance(self.database_engine, PostgresEngine): - # We can use `execute_values` to efficiently fetch the gaps when - # using postgres. - sql = """ - SELECT event_id - FROM event_auth_chains AS c, (VALUES ?) AS l(chain_id, min_seq, max_seq) - WHERE - c.chain_id = l.chain_id - AND min_seq < sequence_number AND sequence_number <= max_seq - """ - args = [ (chain_id, min_no, max_no) for chain_id, (min_no, max_no) in chain_to_gap.items() ] if isinstance(self.database_engine, Psycopg2Engine): + # We can use `execute_values` to efficiently fetch the gaps when + # using postgres. + sql = """ + SELECT event_id + FROM event_auth_chains AS c, (VALUES ?) AS l(chain_id, min_seq, max_seq) + WHERE + c.chain_id = l.chain_id + AND min_seq < sequence_number AND sequence_number <= max_seq + """ + rows = txn.execute_values(sql, args) else: - rows = txn.executemany(sql, args) + sql = """ + COPY ( + SELECT event_id + FROM event_auth_chains AS c, (VALUES %s) AS l(chain_id, min_seq, max_seq) + WHERE + c.chain_id = l.chain_id + AND min_seq < sequence_number AND sequence_number <= max_seq + ) + TO STDOUT + """ % (", ".join("(?, ?, ?)" for _ in args)) + # Flatten the arguments. + rows = txn.copy_read(sql, list(itertools.chain.from_iterable(args))) result.update(r for r, in rows) else: # For SQLite we just fall back to doing a noddy for loop. diff --git a/synapse/storage/databases/main/user_directory.py b/synapse/storage/databases/main/user_directory.py index f589bc781c..96379602ba 100644 --- a/synapse/storage/databases/main/user_directory.py +++ b/synapse/storage/databases/main/user_directory.py @@ -719,18 +719,21 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore): # We weight the localpart most highly, then display name and finally # server name - sql = """ - INSERT INTO user_directory_search(user_id, vector) - VALUES - ( - ?, - setweight(to_tsvector('simple', ?), 'A') - || setweight(to_tsvector('simple', ?), 'D') - || setweight(to_tsvector('simple', COALESCE(?, '')), 'B') - ) - ON CONFLICT (user_id) DO UPDATE SET vector=EXCLUDED.vector - """ if isinstance(self.database_engine, Psycopg2Engine): + template = """ + ( + %s, + setweight(to_tsvector('simple', %s), 'A') + || setweight(to_tsvector('simple', %s), 'D') + || setweight(to_tsvector('simple', COALESCE(%s, '')), 'B') + ) + """ + + sql = """ + INSERT INTO user_directory_search(user_id, vector) + VALUES ? ON CONFLICT (user_id) DO UPDATE SET vector=EXCLUDED.vector + """ + txn.execute_values( sql, [ @@ -744,9 +747,22 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore): ) for p in profiles ], + template=template, fetch=False, ) - if isinstance(self.database_engine, PsycopgEngine): + elif isinstance(self.database_engine, PsycopgEngine): + sql = """ + INSERT INTO user_directory_search(user_id, vector) + VALUES + ( + ?, + setweight(to_tsvector('simple', ?), 'A') + || setweight(to_tsvector('simple', ?), 'D') + || setweight(to_tsvector('simple', COALESCE(?, '')), 'B') + ) + ON CONFLICT (user_id) DO UPDATE SET vector=EXCLUDED.vector + """ + txn.executemany( sql, [