Don't use separate copy_read method.
parent
a072285e9d
commit
a280d117dc
|
@ -14,6 +14,7 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
import inspect
|
import inspect
|
||||||
|
import itertools
|
||||||
import logging
|
import logging
|
||||||
import time
|
import time
|
||||||
import types
|
import types
|
||||||
|
@ -62,7 +63,7 @@ from synapse.storage.engines import (
|
||||||
BaseDatabaseEngine,
|
BaseDatabaseEngine,
|
||||||
Psycopg2Engine,
|
Psycopg2Engine,
|
||||||
PsycopgEngine,
|
PsycopgEngine,
|
||||||
Sqlite3Engine,
|
Sqlite3Engine, PostgresEngine,
|
||||||
)
|
)
|
||||||
from synapse.storage.types import Connection, Cursor, SQLQueryParameters
|
from synapse.storage.types import Connection, Cursor, SQLQueryParameters
|
||||||
from synapse.util.async_helpers import delay_cancellation
|
from synapse.util.async_helpers import delay_cancellation
|
||||||
|
@ -399,7 +400,7 @@ class LoggingTransaction:
|
||||||
def execute_values(
|
def execute_values(
|
||||||
self,
|
self,
|
||||||
sql: str,
|
sql: str,
|
||||||
values: Iterable[Iterable[Any]],
|
values: Sequence[Sequence[Any]],
|
||||||
template: Optional[str] = None,
|
template: Optional[str] = None,
|
||||||
fetch: bool = True,
|
fetch: bool = True,
|
||||||
) -> List[Tuple]:
|
) -> List[Tuple]:
|
||||||
|
@ -412,19 +413,43 @@ class LoggingTransaction:
|
||||||
The `template` is the snippet to merge to every item in argslist to
|
The `template` is the snippet to merge to every item in argslist to
|
||||||
compose the query.
|
compose the query.
|
||||||
"""
|
"""
|
||||||
assert isinstance(self.database_engine, Psycopg2Engine)
|
assert isinstance(self.database_engine, PostgresEngine)
|
||||||
|
|
||||||
from psycopg2.extras import execute_values
|
if isinstance(self.database_engine, Psycopg2Engine):
|
||||||
|
|
||||||
return self._do_execute(
|
from psycopg2.extras import execute_values
|
||||||
# TODO: is it safe for values to be Iterable[Iterable[Any]] here?
|
|
||||||
# https://www.psycopg.org/docs/extras.html?highlight=execute_batch#psycopg2.extras.execute_values says values should be Sequence[Sequence]
|
return self._do_execute(
|
||||||
lambda the_sql, the_values: execute_values(
|
# TODO: is it safe for values to be Iterable[Iterable[Any]] here?
|
||||||
self.txn, the_sql, the_values, template=template, fetch=fetch
|
# https://www.psycopg.org/docs/extras.html?highlight=execute_batch#psycopg2.extras.execute_values says values should be Sequence[Sequence]
|
||||||
),
|
lambda the_sql, the_values: execute_values(
|
||||||
sql,
|
self.txn, the_sql, the_values, template=template, fetch=fetch
|
||||||
values,
|
),
|
||||||
)
|
sql,
|
||||||
|
values,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# We use fetch = False to mean a writable query. You *might* be able
|
||||||
|
# to morph that into a COPY (...) FROM STDIN, but it isn't worth the
|
||||||
|
# effort for the few places we set fetch = False.
|
||||||
|
assert fetch is True
|
||||||
|
|
||||||
|
# execute_values requires a single replacement, but we need to expand it
|
||||||
|
# for COPY. This assumes all inner sequences are the same length.
|
||||||
|
value_str = "(" + ", ".join("?" for _ in next(iter(values))) + ")"
|
||||||
|
sql = sql.replace("?", ", ".join(value_str for _ in values))
|
||||||
|
|
||||||
|
# Wrap the SQL in the COPY statement.
|
||||||
|
sql = f"COPY ({sql}) TO STDOUT"
|
||||||
|
|
||||||
|
def f(
|
||||||
|
the_sql: str, the_args: Sequence[Sequence[Any]]
|
||||||
|
) -> Iterable[Tuple[Any, ...]]:
|
||||||
|
with self.txn.copy(the_sql, the_args) as copy:
|
||||||
|
yield from copy.rows()
|
||||||
|
|
||||||
|
# Flatten the values.
|
||||||
|
return self._do_execute(f, sql, list(itertools.chain.from_iterable(values)))
|
||||||
|
|
||||||
def copy_write(
|
def copy_write(
|
||||||
self, sql: str, args: Iterable[Any], values: Iterable[Iterable[Any]]
|
self, sql: str, args: Iterable[Any], values: Iterable[Iterable[Any]]
|
||||||
|
@ -441,20 +466,6 @@ class LoggingTransaction:
|
||||||
|
|
||||||
self._do_execute(f, sql, args, values)
|
self._do_execute(f, sql, args, values)
|
||||||
|
|
||||||
def copy_read(
|
|
||||||
self, sql: str, args: Iterable[Iterable[Any]]
|
|
||||||
) -> Iterable[Tuple[Any, ...]]:
|
|
||||||
"""Corresponds to a PostgreSQL COPY (...) TO STDOUT call."""
|
|
||||||
assert isinstance(self.database_engine, PsycopgEngine)
|
|
||||||
|
|
||||||
def f(
|
|
||||||
the_sql: str, the_args: Iterable[Iterable[Any]]
|
|
||||||
) -> Iterable[Tuple[Any, ...]]:
|
|
||||||
with self.txn.copy(the_sql, the_args) as copy:
|
|
||||||
yield from copy.rows()
|
|
||||||
|
|
||||||
return self._do_execute(f, sql, args)
|
|
||||||
|
|
||||||
def execute(self, sql: str, parameters: SQLQueryParameters = ()) -> None:
|
def execute(self, sql: str, parameters: SQLQueryParameters = ()) -> None:
|
||||||
self._do_execute(self.txn.execute, sql, parameters)
|
self._do_execute(self.txn.execute, sql, parameters)
|
||||||
|
|
||||||
|
@ -1187,7 +1198,7 @@ class DatabasePool:
|
||||||
txn: LoggingTransaction,
|
txn: LoggingTransaction,
|
||||||
table: str,
|
table: str,
|
||||||
keys: Collection[str],
|
keys: Collection[str],
|
||||||
values: Iterable[Iterable[Any]],
|
values: Sequence[Sequence[Any]],
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Executes an INSERT query on the named table.
|
"""Executes an INSERT query on the named table.
|
||||||
|
|
||||||
|
|
|
@ -311,34 +311,16 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
|
||||||
results = set()
|
results = set()
|
||||||
|
|
||||||
if isinstance(self.database_engine, PostgresEngine):
|
if isinstance(self.database_engine, PostgresEngine):
|
||||||
if isinstance(self.database_engine, Psycopg2Engine):
|
# We can use `execute_values` to efficiently fetch the gaps when
|
||||||
# We can use `execute_values` to efficiently fetch the gaps when
|
# using postgres.
|
||||||
# using postgres.
|
sql = """
|
||||||
sql = """
|
SELECT event_id
|
||||||
SELECT event_id
|
FROM event_auth_chains AS c, (VALUES ?) AS l(chain_id, max_seq)
|
||||||
FROM event_auth_chains AS c, (VALUES ?) AS l(chain_id, max_seq)
|
WHERE
|
||||||
WHERE
|
c.chain_id = l.chain_id
|
||||||
c.chain_id = l.chain_id
|
AND sequence_number <= max_seq
|
||||||
AND sequence_number <= max_seq
|
"""
|
||||||
"""
|
rows = txn.execute_values(sql, chains.items())
|
||||||
rows = txn.execute_values(sql, chains.items())
|
|
||||||
else:
|
|
||||||
sql = """
|
|
||||||
COPY (
|
|
||||||
SELECT event_id
|
|
||||||
FROM event_auth_chains AS c, (VALUES %s) AS l(chain_id, max_seq)
|
|
||||||
WHERE
|
|
||||||
c.chain_id = l.chain_id
|
|
||||||
AND sequence_number <= max_seq
|
|
||||||
)
|
|
||||||
TO STDOUT
|
|
||||||
""" % (
|
|
||||||
", ".join("(?, ?)" for _ in chains)
|
|
||||||
)
|
|
||||||
# Flatten the arguments.
|
|
||||||
rows = txn.copy_read(
|
|
||||||
sql, list(itertools.chain.from_iterable(chains.items()))
|
|
||||||
)
|
|
||||||
results.update(r for r, in rows)
|
results.update(r for r, in rows)
|
||||||
else:
|
else:
|
||||||
# For SQLite we just fall back to doing a noddy for loop.
|
# For SQLite we just fall back to doing a noddy for loop.
|
||||||
|
@ -599,38 +581,22 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
|
||||||
return result
|
return result
|
||||||
|
|
||||||
if isinstance(self.database_engine, PostgresEngine):
|
if isinstance(self.database_engine, PostgresEngine):
|
||||||
|
# We can use `execute_values` to efficiently fetch the gaps when
|
||||||
|
# using postgres.
|
||||||
|
sql = """
|
||||||
|
SELECT event_id
|
||||||
|
FROM event_auth_chains AS c, (VALUES ?) AS l(chain_id, min_seq, max_seq)
|
||||||
|
WHERE
|
||||||
|
c.chain_id = l.chain_id
|
||||||
|
AND min_seq < sequence_number AND sequence_number <= max_seq
|
||||||
|
"""
|
||||||
|
|
||||||
args = [
|
args = [
|
||||||
(chain_id, min_no, max_no)
|
(chain_id, min_no, max_no)
|
||||||
for chain_id, (min_no, max_no) in chain_to_gap.items()
|
for chain_id, (min_no, max_no) in chain_to_gap.items()
|
||||||
]
|
]
|
||||||
|
|
||||||
if isinstance(self.database_engine, Psycopg2Engine):
|
rows = txn.execute_values(sql, args)
|
||||||
# We can use `execute_values` to efficiently fetch the gaps when
|
|
||||||
# using postgres.
|
|
||||||
sql = """
|
|
||||||
SELECT event_id
|
|
||||||
FROM event_auth_chains AS c, (VALUES ?) AS l(chain_id, min_seq, max_seq)
|
|
||||||
WHERE
|
|
||||||
c.chain_id = l.chain_id
|
|
||||||
AND min_seq < sequence_number AND sequence_number <= max_seq
|
|
||||||
"""
|
|
||||||
|
|
||||||
rows = txn.execute_values(sql, args)
|
|
||||||
else:
|
|
||||||
sql = """
|
|
||||||
COPY (
|
|
||||||
SELECT event_id
|
|
||||||
FROM event_auth_chains AS c, (VALUES %s) AS l(chain_id, min_seq, max_seq)
|
|
||||||
WHERE
|
|
||||||
c.chain_id = l.chain_id
|
|
||||||
AND min_seq < sequence_number AND sequence_number <= max_seq
|
|
||||||
)
|
|
||||||
TO STDOUT
|
|
||||||
""" % (
|
|
||||||
", ".join("(?, ?, ?)" for _ in args)
|
|
||||||
)
|
|
||||||
# Flatten the arguments.
|
|
||||||
rows = txn.copy_read(sql, list(itertools.chain.from_iterable(args)))
|
|
||||||
result.update(r for r, in rows)
|
result.update(r for r, in rows)
|
||||||
else:
|
else:
|
||||||
# For SQLite we just fall back to doing a noddy for loop.
|
# For SQLite we just fall back to doing a noddy for loop.
|
||||||
|
|
Loading…
Reference in New Issue