Don't use separate copy_read method.
parent
a072285e9d
commit
a280d117dc
|
@ -14,6 +14,7 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import inspect
|
||||
import itertools
|
||||
import logging
|
||||
import time
|
||||
import types
|
||||
|
@ -62,7 +63,7 @@ from synapse.storage.engines import (
|
|||
BaseDatabaseEngine,
|
||||
Psycopg2Engine,
|
||||
PsycopgEngine,
|
||||
Sqlite3Engine,
|
||||
Sqlite3Engine, PostgresEngine,
|
||||
)
|
||||
from synapse.storage.types import Connection, Cursor, SQLQueryParameters
|
||||
from synapse.util.async_helpers import delay_cancellation
|
||||
|
@ -399,7 +400,7 @@ class LoggingTransaction:
|
|||
def execute_values(
|
||||
self,
|
||||
sql: str,
|
||||
values: Iterable[Iterable[Any]],
|
||||
values: Sequence[Sequence[Any]],
|
||||
template: Optional[str] = None,
|
||||
fetch: bool = True,
|
||||
) -> List[Tuple]:
|
||||
|
@ -412,19 +413,43 @@ class LoggingTransaction:
|
|||
The `template` is the snippet to merge to every item in argslist to
|
||||
compose the query.
|
||||
"""
|
||||
assert isinstance(self.database_engine, Psycopg2Engine)
|
||||
assert isinstance(self.database_engine, PostgresEngine)
|
||||
|
||||
from psycopg2.extras import execute_values
|
||||
if isinstance(self.database_engine, Psycopg2Engine):
|
||||
|
||||
return self._do_execute(
|
||||
# TODO: is it safe for values to be Iterable[Iterable[Any]] here?
|
||||
# https://www.psycopg.org/docs/extras.html?highlight=execute_batch#psycopg2.extras.execute_values says values should be Sequence[Sequence]
|
||||
lambda the_sql, the_values: execute_values(
|
||||
self.txn, the_sql, the_values, template=template, fetch=fetch
|
||||
),
|
||||
sql,
|
||||
values,
|
||||
)
|
||||
from psycopg2.extras import execute_values
|
||||
|
||||
return self._do_execute(
|
||||
# TODO: is it safe for values to be Iterable[Iterable[Any]] here?
|
||||
# https://www.psycopg.org/docs/extras.html?highlight=execute_batch#psycopg2.extras.execute_values says values should be Sequence[Sequence]
|
||||
lambda the_sql, the_values: execute_values(
|
||||
self.txn, the_sql, the_values, template=template, fetch=fetch
|
||||
),
|
||||
sql,
|
||||
values,
|
||||
)
|
||||
else:
|
||||
# We use fetch = False to mean a writable query. You *might* be able
|
||||
# to morph that into a COPY (...) FROM STDIN, but it isn't worth the
|
||||
# effort for the few places we set fetch = False.
|
||||
assert fetch is True
|
||||
|
||||
# execute_values requires a single replacement, but we need to expand it
|
||||
# for COPY. This assumes all inner sequences are the same length.
|
||||
value_str = "(" + ", ".join("?" for _ in next(iter(values))) + ")"
|
||||
sql = sql.replace("?", ", ".join(value_str for _ in values))
|
||||
|
||||
# Wrap the SQL in the COPY statement.
|
||||
sql = f"COPY ({sql}) TO STDOUT"
|
||||
|
||||
def f(
|
||||
the_sql: str, the_args: Sequence[Sequence[Any]]
|
||||
) -> Iterable[Tuple[Any, ...]]:
|
||||
with self.txn.copy(the_sql, the_args) as copy:
|
||||
yield from copy.rows()
|
||||
|
||||
# Flatten the values.
|
||||
return self._do_execute(f, sql, list(itertools.chain.from_iterable(values)))
|
||||
|
||||
def copy_write(
|
||||
self, sql: str, args: Iterable[Any], values: Iterable[Iterable[Any]]
|
||||
|
@ -441,20 +466,6 @@ class LoggingTransaction:
|
|||
|
||||
self._do_execute(f, sql, args, values)
|
||||
|
||||
def copy_read(
|
||||
self, sql: str, args: Iterable[Iterable[Any]]
|
||||
) -> Iterable[Tuple[Any, ...]]:
|
||||
"""Corresponds to a PostgreSQL COPY (...) TO STDOUT call."""
|
||||
assert isinstance(self.database_engine, PsycopgEngine)
|
||||
|
||||
def f(
|
||||
the_sql: str, the_args: Iterable[Iterable[Any]]
|
||||
) -> Iterable[Tuple[Any, ...]]:
|
||||
with self.txn.copy(the_sql, the_args) as copy:
|
||||
yield from copy.rows()
|
||||
|
||||
return self._do_execute(f, sql, args)
|
||||
|
||||
def execute(self, sql: str, parameters: SQLQueryParameters = ()) -> None:
|
||||
self._do_execute(self.txn.execute, sql, parameters)
|
||||
|
||||
|
@ -1187,7 +1198,7 @@ class DatabasePool:
|
|||
txn: LoggingTransaction,
|
||||
table: str,
|
||||
keys: Collection[str],
|
||||
values: Iterable[Iterable[Any]],
|
||||
values: Sequence[Sequence[Any]],
|
||||
) -> None:
|
||||
"""Executes an INSERT query on the named table.
|
||||
|
||||
|
|
|
@ -311,34 +311,16 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
|
|||
results = set()
|
||||
|
||||
if isinstance(self.database_engine, PostgresEngine):
|
||||
if isinstance(self.database_engine, Psycopg2Engine):
|
||||
# We can use `execute_values` to efficiently fetch the gaps when
|
||||
# using postgres.
|
||||
sql = """
|
||||
SELECT event_id
|
||||
FROM event_auth_chains AS c, (VALUES ?) AS l(chain_id, max_seq)
|
||||
WHERE
|
||||
c.chain_id = l.chain_id
|
||||
AND sequence_number <= max_seq
|
||||
"""
|
||||
rows = txn.execute_values(sql, chains.items())
|
||||
else:
|
||||
sql = """
|
||||
COPY (
|
||||
SELECT event_id
|
||||
FROM event_auth_chains AS c, (VALUES %s) AS l(chain_id, max_seq)
|
||||
WHERE
|
||||
c.chain_id = l.chain_id
|
||||
AND sequence_number <= max_seq
|
||||
)
|
||||
TO STDOUT
|
||||
""" % (
|
||||
", ".join("(?, ?)" for _ in chains)
|
||||
)
|
||||
# Flatten the arguments.
|
||||
rows = txn.copy_read(
|
||||
sql, list(itertools.chain.from_iterable(chains.items()))
|
||||
)
|
||||
# We can use `execute_values` to efficiently fetch the gaps when
|
||||
# using postgres.
|
||||
sql = """
|
||||
SELECT event_id
|
||||
FROM event_auth_chains AS c, (VALUES ?) AS l(chain_id, max_seq)
|
||||
WHERE
|
||||
c.chain_id = l.chain_id
|
||||
AND sequence_number <= max_seq
|
||||
"""
|
||||
rows = txn.execute_values(sql, chains.items())
|
||||
results.update(r for r, in rows)
|
||||
else:
|
||||
# For SQLite we just fall back to doing a noddy for loop.
|
||||
|
@ -599,38 +581,22 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
|
|||
return result
|
||||
|
||||
if isinstance(self.database_engine, PostgresEngine):
|
||||
# We can use `execute_values` to efficiently fetch the gaps when
|
||||
# using postgres.
|
||||
sql = """
|
||||
SELECT event_id
|
||||
FROM event_auth_chains AS c, (VALUES ?) AS l(chain_id, min_seq, max_seq)
|
||||
WHERE
|
||||
c.chain_id = l.chain_id
|
||||
AND min_seq < sequence_number AND sequence_number <= max_seq
|
||||
"""
|
||||
|
||||
args = [
|
||||
(chain_id, min_no, max_no)
|
||||
for chain_id, (min_no, max_no) in chain_to_gap.items()
|
||||
]
|
||||
|
||||
if isinstance(self.database_engine, Psycopg2Engine):
|
||||
# We can use `execute_values` to efficiently fetch the gaps when
|
||||
# using postgres.
|
||||
sql = """
|
||||
SELECT event_id
|
||||
FROM event_auth_chains AS c, (VALUES ?) AS l(chain_id, min_seq, max_seq)
|
||||
WHERE
|
||||
c.chain_id = l.chain_id
|
||||
AND min_seq < sequence_number AND sequence_number <= max_seq
|
||||
"""
|
||||
|
||||
rows = txn.execute_values(sql, args)
|
||||
else:
|
||||
sql = """
|
||||
COPY (
|
||||
SELECT event_id
|
||||
FROM event_auth_chains AS c, (VALUES %s) AS l(chain_id, min_seq, max_seq)
|
||||
WHERE
|
||||
c.chain_id = l.chain_id
|
||||
AND min_seq < sequence_number AND sequence_number <= max_seq
|
||||
)
|
||||
TO STDOUT
|
||||
""" % (
|
||||
", ".join("(?, ?, ?)" for _ in args)
|
||||
)
|
||||
# Flatten the arguments.
|
||||
rows = txn.copy_read(sql, list(itertools.chain.from_iterable(args)))
|
||||
rows = txn.execute_values(sql, args)
|
||||
result.update(r for r, in rows)
|
||||
else:
|
||||
# For SQLite we just fall back to doing a noddy for loop.
|
||||
|
|
Loading…
Reference in New Issue