1055 lines
		
	
	
		
			33 KiB
		
	
	
	
		
			Python
		
	
	
		
			Executable File
		
	
			
		
		
	
	
			1055 lines
		
	
	
		
			33 KiB
		
	
	
	
		
			Python
		
	
	
		
			Executable File
		
	
#!/usr/bin/env python
 | 
						|
# -*- coding: utf-8 -*-
 | 
						|
# Copyright 2015, 2016 OpenMarket Ltd
 | 
						|
# Copyright 2018 New Vector Ltd
 | 
						|
# Copyright 2019 The Matrix.org Foundation C.I.C.
 | 
						|
#
 | 
						|
# Licensed under the Apache License, Version 2.0 (the "License");
 | 
						|
# you may not use this file except in compliance with the License.
 | 
						|
# You may obtain a copy of the License at
 | 
						|
#
 | 
						|
#     http://www.apache.org/licenses/LICENSE-2.0
 | 
						|
#
 | 
						|
# Unless required by applicable law or agreed to in writing, software
 | 
						|
# distributed under the License is distributed on an "AS IS" BASIS,
 | 
						|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
						|
# See the License for the specific language governing permissions and
 | 
						|
# limitations under the License.
 | 
						|
 | 
						|
import argparse
 | 
						|
import curses
 | 
						|
import logging
 | 
						|
import sys
 | 
						|
import time
 | 
						|
import traceback
 | 
						|
 | 
						|
from six import string_types
 | 
						|
 | 
						|
import yaml
 | 
						|
 | 
						|
from twisted.enterprise import adbapi
 | 
						|
from twisted.internet import defer, reactor
 | 
						|
 | 
						|
from synapse.config.homeserver import HomeServerConfig
 | 
						|
from synapse.logging.context import PreserveLoggingContext
 | 
						|
from synapse.storage._base import LoggingTransaction
 | 
						|
from synapse.storage.data_stores.main.client_ips import ClientIpBackgroundUpdateStore
 | 
						|
from synapse.storage.data_stores.main.deviceinbox import (
 | 
						|
    DeviceInboxBackgroundUpdateStore,
 | 
						|
)
 | 
						|
from synapse.storage.data_stores.main.devices import DeviceBackgroundUpdateStore
 | 
						|
from synapse.storage.data_stores.main.events_bg_updates import (
 | 
						|
    EventsBackgroundUpdatesStore,
 | 
						|
)
 | 
						|
from synapse.storage.data_stores.main.media_repository import (
 | 
						|
    MediaRepositoryBackgroundUpdateStore,
 | 
						|
)
 | 
						|
from synapse.storage.data_stores.main.registration import (
 | 
						|
    RegistrationBackgroundUpdateStore,
 | 
						|
)
 | 
						|
from synapse.storage.data_stores.main.roommember import RoomMemberBackgroundUpdateStore
 | 
						|
from synapse.storage.data_stores.main.search import SearchBackgroundUpdateStore
 | 
						|
from synapse.storage.data_stores.main.state import StateBackgroundUpdateStore
 | 
						|
from synapse.storage.data_stores.main.stats import StatsStore
 | 
						|
from synapse.storage.data_stores.main.user_directory import (
 | 
						|
    UserDirectoryBackgroundUpdateStore,
 | 
						|
)
 | 
						|
from synapse.storage.engines import create_engine
 | 
						|
from synapse.storage.prepare_database import prepare_database
 | 
						|
from synapse.util import Clock
 | 
						|
 | 
						|
# Module-level logger for the port script.
logger = logging.getLogger("synapse_port_db")


# Maps a table name to the columns of that table whose values are passed
# through bool() when rows are converted for insertion (see
# Porter._convert_rows).
BOOLEAN_COLUMNS = {
    "events": ["processed", "outlier", "contains_url"],
    "rooms": ["is_public"],
    "event_edges": ["is_state"],
    "presence_list": ["accepted"],
    "presence_stream": ["currently_active"],
    "public_room_list_stream": ["visibility"],
    "devices": ["hidden"],
    "device_lists_outbound_pokes": ["sent"],
    "users_who_share_rooms": ["share_private"],
    "groups": ["is_public"],
    "group_rooms": ["is_public"],
    "group_users": ["is_public", "is_admin"],
    "group_summary_rooms": ["is_public"],
    "group_room_categories": ["is_public"],
    "group_summary_users": ["is_public"],
    "group_roles": ["is_public"],
    "local_group_membership": ["is_publicised", "is_admin"],
    "e2e_room_keys": ["is_verified"],
    "account_validity": ["email_sent"],
    "redactions": ["have_censored"],
    "room_stats_state": ["is_federatable"],
}


# Tables for which an interrupted port is resumed from the row ids recorded in
# port_from_sqlite3. Any table not listed here is truncated and copied again
# from scratch (see Porter.setup_table).
APPEND_ONLY_TABLES = [
    "event_reference_hashes",
    "events",
    "event_json",
    "state_events",
    "room_memberships",
    "topics",
    "room_names",
    "rooms",
    "local_media_repository",
    "local_media_repository_thumbnails",
    "remote_media_cache",
    "remote_media_cache_thumbnails",
    "redactions",
    "event_edges",
    "event_auth",
    "received_transactions",
    "sent_transactions",
    "transaction_id_to_pdu",
    "users",
    "state_groups",
    "state_groups_state",
    "event_to_state_groups",
    "rejections",
    "event_search",
    "presence_stream",
    "push_rules_stream",
    "ex_outlier_stream",
    "cache_invalidation_stream",
    "public_room_list_stream",
    "state_group_edges",
    "stream_ordering_to_exterm",
]


# Set by Porter.run to sys.exc_info() when the port fails; None otherwise.
end_error_exec_info = None
 | 
						|
 | 
						|
 | 
						|
class Store(
    ClientIpBackgroundUpdateStore,
    DeviceInboxBackgroundUpdateStore,
    DeviceBackgroundUpdateStore,
    EventsBackgroundUpdatesStore,
    MediaRepositoryBackgroundUpdateStore,
    RegistrationBackgroundUpdateStore,
    RoomMemberBackgroundUpdateStore,
    SearchBackgroundUpdateStore,
    StateBackgroundUpdateStore,
    UserDirectoryBackgroundUpdateStore,
    StatsStore,
):
    """Store used by the port script for both the source and target database.

    Combines the background-update store classes listed as bases and runs
    interactions on the connection pool handed out by the homeserver object.
    """

    def __init__(self, db_conn, hs):
        super().__init__(db_conn, hs)
        self.db_pool = hs.get_db_pool()

    @defer.inlineCallbacks
    def runInteraction(self, desc, func, *args, **kwargs):
        """Run ``func`` with a LoggingTransaction on a pooled connection.

        Args:
            desc: Description of the interaction, used in log messages.
            func: Callable invoked as ``func(txn, *args, **kwargs)``.
            *args: Extra positional arguments for ``func``.
            **kwargs: Extra keyword arguments for ``func``.

        Returns:
            Deferred firing with whatever ``func`` returns.

        Raises:
            Whatever ``func`` raises. A database error that the engine reports
            as a deadlock is retried (after a rollback) up to five times
            before being re-raised.
        """
        max_retries = 5

        def r(conn):
            try:
                attempt = 0
                while True:
                    try:
                        txn = conn.cursor()
                        return func(
                            LoggingTransaction(txn, desc, self.database_engine, [], []),
                            *args,
                            **kwargs
                        )
                    except self.database_engine.module.DatabaseError as e:
                        if self.database_engine.is_deadlock(e):
                            logger.warning(
                                "[TXN DEADLOCK] {%s} %d/%d", desc, attempt, max_retries
                            )
                            if attempt < max_retries:
                                attempt += 1
                                conn.rollback()
                                continue
                        # Not a deadlock, or out of retries.
                        raise
            except Exception as e:
                logger.debug("[TXN FAIL] {%s} %s", desc, e)
                raise

        with PreserveLoggingContext():
            return (yield self.db_pool.runWithConnection(r))

    def execute(self, f, *args, **kwargs):
        """Run ``f`` as an interaction, using its ``__name__`` as description."""
        return self.runInteraction(f.__name__, f, *args, **kwargs)

    def execute_sql(self, sql, *args):
        """Execute ``sql`` with ``args`` as parameters and return all rows."""

        def r(txn):
            txn.execute(sql, args)
            return txn.fetchall()

        return self.runInteraction("execute_sql", r)

    def insert_many_txn(self, txn, table, headers, rows):
        """Insert ``rows`` into ``table`` using a single ``executemany`` call.

        Args:
            txn: Transaction/cursor to execute on.
            table: Name of the table to insert into.
            headers: Column names, in the same order as the values in each row.
            rows: Iterable of value tuples.

        Raises:
            Any exception from ``executemany``; it is logged and re-raised.
        """
        sql = "INSERT INTO %s (%s) VALUES (%s)" % (
            table,
            ", ".join(headers),
            ", ".join("%s" for _ in headers),
        )

        try:
            txn.executemany(sql, rows)
        except Exception:
            logger.exception("Failed to insert: %s", table)
            raise
 | 
						|
 | 
						|
 | 
						|
class MockHomeserver:
    """Minimal stand-in for a homeserver object, passed to ``Store``.

    It only holds the objects given to the constructor plus a clock, and
    exposes them through the getters below.
    """

    def __init__(self, config, database_engine, db_conn, db_pool):
        """Store the supplied objects and create a clock on the reactor.

        Args:
            config: Homeserver configuration; its ``server_name`` is exposed
                as ``hostname``.
            database_engine: Database engine object for this database.
            db_conn: Open database connection.
            db_pool: Connection pool used to run interactions.
        """
        self.database_engine = database_engine
        self.db_conn = db_conn
        self.db_pool = db_pool
        self.clock = Clock(reactor)
        self.config = config
        self.hostname = config.server_name

    def get_db_conn(self):
        """Return the database connection given to the constructor."""
        return self.db_conn

    def get_db_pool(self):
        """Return the connection pool given to the constructor."""
        return self.db_pool

    def get_clock(self):
        """Return the clock created in the constructor."""
        return self.clock
 | 
						|
 | 
						|
 | 
						|
class Porter(object):
 | 
						|
    def __init__(self, **kwargs):
 | 
						|
        self.__dict__.update(kwargs)
 | 
						|
 | 
						|
    @defer.inlineCallbacks
    def setup_table(self, table):
        """Work out where porting of ``table`` should start.

        For tables in APPEND_ONLY_TABLES the progress recorded in
        ``port_from_sqlite3`` is reused (or created if missing). Every other
        table is truncated in PostgreSQL and its progress row reset, so it is
        copied again from the beginning.

        Args:
            table: Name of the table to prepare.

        Returns:
            Deferred firing with the tuple ``(table, already_ported,
            total_to_port, forward_chunk, backward_chunk)``, which matches the
            positional arguments of ``handle_table``.
        """
        if table in APPEND_ONLY_TABLES:
            # It's safe to just carry on inserting.
            row = yield self.postgres_store._simple_select_one(
                table="port_from_sqlite3",
                keyvalues={"table_name": table},
                retcols=("forward_rowid", "backward_rowid"),
                allow_none=True,
            )

            total_to_port = None
            if row is None:
                if table == "sent_transactions":
                    # sent_transactions gets special treatment; the helper
                    # also creates the progress row and returns the counts.
                    forward_chunk, already_ported, total_to_port = (
                        yield self._setup_sent_transactions()
                    )
                    backward_chunk = 0
                else:
                    yield self.postgres_store._simple_insert(
                        table="port_from_sqlite3",
                        values={
                            "table_name": table,
                            "forward_rowid": 1,
                            "backward_rowid": 0,
                        },
                    )

                    forward_chunk = 1
                    backward_chunk = 0
                    already_ported = 0
            else:
                forward_chunk = row["forward_rowid"]
                backward_chunk = row["backward_rowid"]

            if total_to_port is None:
                already_ported, total_to_port = yield self._get_total_count_to_port(
                    table, forward_chunk, backward_chunk
                )
        else:

            def delete_all(txn):
                # Forget any recorded progress and empty the target table.
                txn.execute(
                    "DELETE FROM port_from_sqlite3 WHERE table_name = %s", (table,)
                )
                txn.execute("TRUNCATE %s CASCADE" % (table,))

            yield self.postgres_store.execute(delete_all)

            yield self.postgres_store._simple_insert(
                table="port_from_sqlite3",
                values={"table_name": table, "forward_rowid": 1, "backward_rowid": 0},
            )

            forward_chunk = 1
            backward_chunk = 0

            already_ported, total_to_port = yield self._get_total_count_to_port(
                table, forward_chunk, backward_chunk
            )

        defer.returnValue(
            (table, already_ported, total_to_port, forward_chunk, backward_chunk)
        )
 | 
						|
 | 
						|
    @defer.inlineCallbacks
    def handle_table(
        self, table, postgres_size, table_size, forward_chunk, backward_chunk
    ):
        """Copy the remaining rows of ``table`` from SQLite to PostgreSQL.

        Rows are read in batches of ``self.batch_size`` in two directions:
        upwards from ``forward_chunk`` and downwards from ``backward_chunk``.
        Each batch is inserted in the same PostgreSQL transaction that updates
        the progress row in ``port_from_sqlite3``.

        Args:
            table: Name of the table to port.
            postgres_size: Number of rows already ported.
            table_size: Total number of rows to port.
            forward_chunk: Lowest rowid not yet ported in the forward direction.
            backward_chunk: Highest rowid not yet ported in the backward
                direction.

        Returns:
            Deferred firing with None once no more rows are found.
        """
        logger.info(
            "Table %s: %i/%i (rows %i-%i) already ported",
            table,
            postgres_size,
            table_size,
            backward_chunk + 1,
            forward_chunk - 1,
        )

        if not table_size:
            return

        self.progress.add_table(table, postgres_size, table_size)

        if table == "event_search":
            yield self.handle_search_table(
                postgres_size, table_size, forward_chunk, backward_chunk
            )
            return

        # NOTE(review): "users_in_pubic_room" looks like a misspelling of a
        # table name; it is left unchanged here because correcting it would
        # alter which tables are skipped. Confirm against the schema.
        if table in (
            "user_directory",
            "user_directory_search",
            "users_who_share_rooms",
            "users_in_pubic_room",
        ):
            # We don't port these tables, as they're a faff and we can
            # regenerate them anyway.
            self.progress.update(table, table_size)  # Mark table as done
            return

        if table == "user_directory_stream_pos":
            # We need to make sure there is a single row, `(X, null)`, as that
            # is what synapse expects to be there.
            yield self.postgres_store._simple_insert(
                table=table, values={"stream_id": None}
            )
            self.progress.update(table, table_size)  # Mark table as done
            return

        forward_select = (
            "SELECT rowid, * FROM %s WHERE rowid >= ? ORDER BY rowid LIMIT ?" % (table,)
        )

        # Descending order so that each batch holds the rows closest to
        # `backward_chunk`. With ascending order the batch would hold the
        # lowest rowids, and moving `backward_chunk` below them would skip
        # every row in between.
        backward_select = (
            "SELECT rowid, * FROM %s WHERE rowid <= ? ORDER BY rowid DESC LIMIT ?"
            % (table,)
        )

        # One-element lists so the nested function can flip the flags.
        do_forward = [True]
        do_backward = [True]

        while True:

            def r(txn):
                forward_rows = []
                backward_rows = []
                if do_forward[0]:
                    txn.execute(forward_select, (forward_chunk, self.batch_size))
                    forward_rows = txn.fetchall()
                    if not forward_rows:
                        do_forward[0] = False

                if do_backward[0]:
                    txn.execute(backward_select, (backward_chunk, self.batch_size))
                    backward_rows = txn.fetchall()
                    if not backward_rows:
                        do_backward[0] = False

                if forward_rows or backward_rows:
                    headers = [column[0] for column in txn.description]
                else:
                    headers = None

                return headers, forward_rows, backward_rows

            headers, frows, brows = yield self.sqlite_store.runInteraction("select", r)

            if frows or brows:
                # Column 0 of every selected row is the SQLite rowid.
                if frows:
                    forward_chunk = max(row[0] for row in frows) + 1
                if brows:
                    backward_chunk = min(row[0] for row in brows) - 1

                rows = frows + brows
                rows = self._convert_rows(table, headers, rows)

                def insert(txn):
                    # headers[1:] skips the leading rowid column, which
                    # _convert_rows has removed from the rows as well.
                    self.postgres_store.insert_many_txn(txn, table, headers[1:], rows)

                    self.postgres_store._simple_update_one_txn(
                        txn,
                        table="port_from_sqlite3",
                        keyvalues={"table_name": table},
                        updatevalues={
                            "forward_rowid": forward_chunk,
                            "backward_rowid": backward_chunk,
                        },
                    )

                yield self.postgres_store.execute(insert)

                postgres_size += len(rows)

                self.progress.update(table, postgres_size)
            else:
                return
 | 
						|
 | 
						|
    @defer.inlineCallbacks
    def handle_search_table(
        self, postgres_size, table_size, forward_chunk, backward_chunk
    ):
        """Copy the ``event_search`` table from SQLite to PostgreSQL.

        Rows are read forwards from ``forward_chunk`` in batches of
        ``self.batch_size``, joined against ``events`` to pick up
        ``origin_server_ts`` and ``stream_ordering``. Rows whose ``value``
        contains a NUL character are logged and not inserted.

        Args:
            postgres_size: Number of rows already ported.
            table_size: Total number of rows to port (not used here).
            forward_chunk: Lowest rowid not yet ported.
            backward_chunk: Written back unchanged to the progress row.

        Returns:
            Deferred firing with None once a batch comes back empty.
        """
        select = (
            "SELECT es.rowid, es.*, e.origin_server_ts, e.stream_ordering"
            " FROM event_search as es"
            " INNER JOIN events AS e USING (event_id, room_id)"
            " WHERE es.rowid >= ?"
            " ORDER BY es.rowid LIMIT ?"
        )

        while True:

            def r(txn):
                txn.execute(select, (forward_chunk, self.batch_size))
                rows = txn.fetchall()
                headers = [column[0] for column in txn.description]

                return headers, rows

            headers, rows = yield self.sqlite_store.runInteraction("select", r)

            if rows:
                # Rows are ordered by rowid, so the last one has the highest.
                forward_chunk = rows[-1][0] + 1

                # We have to treat event_search differently since it has a
                # different structure in the two different databases.
                def insert(txn):
                    sql = (
                        "INSERT INTO event_search (event_id, room_id, key,"
                        " sender, vector, origin_server_ts, stream_ordering)"
                        " VALUES (?,?,?,?,to_tsvector('english', ?),?,?)"
                    )

                    rows_dict = []
                    for row in rows:
                        d = dict(zip(headers, row))
                        if "\0" in d['value']:
                            logger.warning('dropping search row %s', d)
                        else:
                            rows_dict.append(d)

                    txn.executemany(
                        sql,
                        [
                            (
                                row["event_id"],
                                row["room_id"],
                                row["key"],
                                row["sender"],
                                row["value"],
                                row["origin_server_ts"],
                                row["stream_ordering"],
                            )
                            for row in rows_dict
                        ],
                    )

                    self.postgres_store._simple_update_one_txn(
                        txn,
                        table="port_from_sqlite3",
                        keyvalues={"table_name": "event_search"},
                        updatevalues={
                            "forward_rowid": forward_chunk,
                            "backward_rowid": backward_chunk,
                        },
                    )

                yield self.postgres_store.execute(insert)

                # Counts every row read, including any that were dropped.
                postgres_size += len(rows)

                self.progress.update("event_search", postgres_size)

            else:
                return
 | 
						|
 | 
						|
    def setup_db(self, db_config, database_engine):
        """Open a connection to a database and prepare it.

        Args:
            db_config: Database configuration dict; the entries under "args"
                whose key does not start with "cp_" are passed to the database
                module's ``connect``.
            database_engine: Engine whose ``module.connect`` opens the
                connection.

        Returns:
            The opened connection, after ``prepare_database`` has been run on
            it and the result committed.
        """
        connect_kwargs = {
            k: v
            for k, v in db_config.get("args", {}).items()
            if not k.startswith("cp_")
        }
        db_conn = database_engine.module.connect(**connect_kwargs)

        prepare_database(db_conn, database_engine, config=None)

        db_conn.commit()

        return db_conn
 | 
						|
 | 
						|
    @defer.inlineCallbacks
    def build_db_store(self, config):
        """Builds and returns a database store using the provided configuration.

        Creates the engine, prepares the database through ``setup_db``, builds
        a connection pool from ``config["name"]`` and ``config["args"]``, and
        runs the engine's ``check_database`` on the new store.

        Args:
            config: The database configuration, i.e. a dict following the structure of
                the "database" section of Synapse's configuration file.

        Returns:
            The built Store object.
        """
        engine = create_engine(config)

        self.progress.set_state("Preparing %s" % config["name"])
        conn = self.setup_db(config, engine)

        db_pool = adbapi.ConnectionPool(config["name"], **config["args"])

        hs = MockHomeserver(self.hs_config, engine, conn, db_pool)

        store = Store(conn, hs)

        yield store.runInteraction(
            "%s_engine.check_database" % config["name"],
            engine.check_database,
        )

        return store
 | 
						|
 | 
						|
    @defer.inlineCallbacks
    def run_background_updates_on_postgres(self):
        """Run background updates on PostgreSQL until none are left.

        Calls ``do_next_background_update(100)`` on the PostgreSQL store in a
        loop until ``has_completed_background_updates()`` reports completion.

        Returns:
            Deferred firing with None when all updates have completed.
        """
        # Manually apply all background updates on the PostgreSQL database.
        postgres_ready = yield self.postgres_store.has_completed_background_updates()

        if not postgres_ready:
            # Only say that we're running background updates when there are background
            # updates to run.
            self.progress.set_state("Running background updates on PostgreSQL")

        while not postgres_ready:
            yield self.postgres_store.do_next_background_update(100)
            postgres_ready = yield (
                self.postgres_store.has_completed_background_updates()
            )
 | 
						|
 | 
						|
    @defer.inlineCallbacks
    def run(self):
        """Run the whole port and stop the reactor when finished.

        Builds both stores, aborts if the SQLite database still has pending
        background updates, runs background updates on PostgreSQL, creates the
        ``port_from_sqlite3`` progress table, then sets up and copies every
        table present in both databases.

        Any exception is logged and recorded in the module-level
        ``end_error_exec_info``; the reactor is stopped in every case.
        """
        try:
            self.sqlite_store = yield self.build_db_store(self.sqlite_config)

            # Check if all background updates are done, abort if not.
            updates_complete = yield self.sqlite_store.has_completed_background_updates()
            if not updates_complete:
                sys.stderr.write(
                    "Pending background updates exist in the SQLite3 database."
                    " Please start Synapse again and wait until every update has finished"
                    " before running this script.\n"
                )
                defer.returnValue(None)

            self.postgres_store = yield self.build_db_store(
                self.hs_config.database_config
            )

            yield self.run_background_updates_on_postgres()

            self.progress.set_state("Creating port tables")

            def create_port_table(txn):
                txn.execute(
                    "CREATE TABLE IF NOT EXISTS port_from_sqlite3 ("
                    " table_name varchar(100) NOT NULL UNIQUE,"
                    " forward_rowid bigint NOT NULL,"
                    " backward_rowid bigint NOT NULL"
                    ")"
                )

            # The old port script created a table with just a "rowid" column.
            # We want people to be able to rerun this script from an old port
            # so that they can pick up any missing events that were not
            # ported across.
            def alter_table(txn):
                txn.execute(
                    "ALTER TABLE IF EXISTS port_from_sqlite3"
                    " RENAME rowid TO forward_rowid"
                )
                txn.execute(
                    "ALTER TABLE IF EXISTS port_from_sqlite3"
                    " ADD backward_rowid bigint NOT NULL DEFAULT 0"
                )

            try:
                yield self.postgres_store.runInteraction("alter_table", alter_table)
            except Exception:
                # Best effort: carry on regardless, but leave a trace in the
                # log rather than discarding the error silently.
                logger.debug(
                    "Ignoring failure to alter port_from_sqlite3", exc_info=True
                )

            yield self.postgres_store.runInteraction(
                "create_port_table", create_port_table
            )

            # Step 2. Get tables.
            self.progress.set_state("Fetching tables")
            sqlite_tables = yield self.sqlite_store._simple_select_onecol(
                table="sqlite_master", keyvalues={"type": "table"}, retcol="name"
            )

            postgres_tables = yield self.postgres_store._simple_select_onecol(
                table="information_schema.tables",
                keyvalues={},
                retcol="distinct table_name",
            )

            # Only tables that exist in both databases are ported.
            tables = set(sqlite_tables) & set(postgres_tables)
            logger.info("Found %d tables", len(tables))

            # Step 3. Figure out what still needs copying
            self.progress.set_state("Checking on port progress")
            setup_res = yield defer.gatherResults(
                [
                    self.setup_table(table)
                    for table in tables
                    if table not in ["schema_version", "applied_schema_deltas"]
                    and not table.startswith("sqlite_")
                ],
                consumeErrors=True,
            )

            # Step 4. Do the copying.
            self.progress.set_state("Copying to postgres")
            yield defer.gatherResults(
                [self.handle_table(*res) for res in setup_res], consumeErrors=True
            )

            # Step 5. Do final post-processing
            yield self._setup_state_group_id_seq()

            self.progress.done()
        except Exception:
            global end_error_exec_info
            end_error_exec_info = sys.exc_info()
            logger.exception("")
        finally:
            reactor.stop()
 | 
						|
 | 
						|
    def _convert_rows(self, table, headers, rows):
        """Convert rows for insertion, removing the leading column.

        The first column of every row (index 0) is left out of the output. Of
        the remaining values, those in columns listed for ``table`` in
        BOOLEAN_COLUMNS are passed through ``bool()`` and ``bytes`` values are
        wrapped in ``bytearray``. A row containing a string with a NUL
        character is logged and left out of the result entirely.

        Args:
            table: Name of the table the rows belong to.
            headers: Column names for the rows, including the leading column.
            rows: Iterable of row tuples.

        Returns:
            A list of converted row tuples, each one element shorter than the
            corresponding input row.
        """
        bool_col_names = BOOLEAN_COLUMNS.get(table, [])

        # Positions of the boolean columns within each row.
        bool_cols = {i for i, h in enumerate(headers) if h in bool_col_names}

        class BadValueException(Exception):
            """Raised by ``conv`` to signal that the whole row must be dropped."""

        def conv(j, col):
            if j in bool_cols:
                return bool(col)
            if isinstance(col, bytes):
                return bytearray(col)
            if isinstance(col, string_types) and "\0" in col:
                logger.warning(
                    "DROPPING ROW: NUL value in table %s col %s: %r",
                    table,
                    headers[j],
                    col,
                )
                raise BadValueException()
            return col

        outrows = []
        for row in rows:
            try:
                outrows.append(
                    tuple(conv(j, col) for j, col in enumerate(row) if j > 0)
                )
            except BadValueException:
                # Already logged in conv(); skip this row.
                pass

        return outrows
 | 
						|
 | 
						|
    @defer.inlineCallbacks
    def _setup_sent_transactions(self):
        """Prepare the ``sent_transactions`` table for porting.

        Copies to Postgres the newest row of each destination when that row is
        more than a day old, records in ``port_from_sqlite3`` the rowid the
        regular port should start from, and counts the rows involved.

        Returns:
            Deferred firing with ``(next_chunk, inserted_rows, total_count)``:
            the rowid to start porting from, the number of rows inserted here,
            and that number plus the count of rows from the last day.
        """
        # Cut-off in milliseconds since the epoch: rows with an older ``ts``
        # are not picked up by the counts/start-id queries further down.
        day_in_ms = 86400000
        yesterday = int(time.time() * 1000) - day_in_ms

        # Newest row (highest rowid) for every destination.
        select = (
            "SELECT rowid, * FROM sent_transactions WHERE rowid IN ("
            "SELECT max(rowid) FROM sent_transactions"
            " GROUP BY destination"
            ")"
        )

        def fetch_latest_per_destination(txn):
            txn.execute(select)
            fetched = txn.fetchall()
            column_names = [column[0] for column in txn.description]

            ts_ind = column_names.index('ts')

            # Keep only the rows that are older than the cut-off.
            stale = [row for row in fetched if row[ts_ind] < yesterday]
            return column_names, stale

        headers, rows = yield self.sqlite_store.runInteraction(
            "select", fetch_latest_per_destination
        )

        rows = self._convert_rows("sent_transactions", headers, rows)

        inserted_rows = len(rows)
        max_inserted_rowid = 0
        if inserted_rows:
            # NOTE(review): _convert_rows drops column 0 (the selected rowid),
            # so row[0] is the first real table column -- presumably an ``id``
            # that mirrors the rowid; confirm against the schema.
            max_inserted_rowid = max(row[0] for row in rows)

            def insert(txn):
                # headers[0] is the extra "rowid" column, which is not inserted.
                self.postgres_store.insert_many_txn(
                    txn, "sent_transactions", headers[1:], rows
                )

            yield self.postgres_store.execute(insert)

        def get_start_id(txn):
            txn.execute(
                "SELECT rowid FROM sent_transactions WHERE ts >= ?"
                " ORDER BY rowid ASC LIMIT 1",
                (yesterday,),
            )

            found = txn.fetchall()
            # Fall back to rowid 1 when nothing is recent enough.
            return found[0][0] if found else 1

        next_chunk = yield self.sqlite_store.execute(get_start_id)
        # Never start below the rows that were just inserted.
        next_chunk = max(max_inserted_rowid + 1, next_chunk)

        port_state = {
            "table_name": "sent_transactions",
            "forward_rowid": next_chunk,
            "backward_rowid": 0,
        }
        yield self.postgres_store._simple_insert(
            table="port_from_sqlite3", values=port_state
        )

        def get_sent_table_size(txn):
            txn.execute(
                "SELECT count(*) FROM sent_transactions WHERE ts >= ?", (yesterday,)
            )
            (size,) = txn.fetchone()
            return int(size)

        remaining_count = yield self.sqlite_store.execute(get_sent_table_size)

        total_count = remaining_count + inserted_rows

        defer.returnValue((next_chunk, inserted_rows, total_count))
 | 
						|
 | 
						|
    @defer.inlineCallbacks
    def _get_remaining_count_to_port(self, table, forward_chunk, backward_chunk):
        """Count the SQLite rows of ``table`` that still have to be ported.

        Args:
            table: name of the table; interpolated directly into the SQL.
            forward_chunk: rows with ``rowid >= forward_chunk`` are counted.
            backward_chunk: rows with ``rowid <= backward_chunk`` are counted.

        Returns:
            Deferred firing with the sum of the two counts.
        """
        forward_sql = "SELECT count(*) FROM %s WHERE rowid >= ?" % (table,)
        forward_rows = yield self.sqlite_store.execute_sql(forward_sql, forward_chunk)

        backward_sql = "SELECT count(*) FROM %s WHERE rowid <= ?" % (table,)
        backward_rows = yield self.sqlite_store.execute_sql(
            backward_sql, backward_chunk
        )

        # Each query returns a single row holding a single count.
        total = forward_rows[0][0] + backward_rows[0][0]
        defer.returnValue(total)
 | 
						|
 | 
						|
    @defer.inlineCallbacks
    def _get_already_ported_count(self, table):
        """Count the rows of ``table`` currently present in Postgres.

        Args:
            table: name of the table; interpolated directly into the SQL.

        Returns:
            Deferred firing with the value of ``count(*)`` for that table.
        """
        count_sql = "SELECT count(*) FROM %s" % (table,)
        result = yield self.postgres_store.execute_sql(count_sql)

        # Single row, single column: the count itself.
        defer.returnValue(result[0][0])
 | 
						|
 | 
						|
    @defer.inlineCallbacks
    def _get_total_count_to_port(self, table, forward_chunk, backward_chunk):
        """Work out how much of ``table`` is done and how big it is overall.

        Both counts are requested together and awaited via
        ``defer.gatherResults``.

        Args:
            table: name of the table being ported.
            forward_chunk: forward rowid bound passed to the remaining-count
                query.
            backward_chunk: backward rowid bound passed to the remaining-count
                query.

        Returns:
            Deferred firing with ``(done, total)`` where ``total`` is the
            number of rows still to port plus the number already ported.
        """
        pending = [
            self._get_remaining_count_to_port(table, forward_chunk, backward_chunk),
            self._get_already_ported_count(table),
        ]
        remaining, done = yield defer.gatherResults(pending, consumeErrors=True)

        # Treat falsy results (e.g. None or 0) as zero rows.
        remaining = int(remaining or 0)
        done = int(done or 0)

        defer.returnValue((done, remaining + done))
 | 
						|
 | 
						|
    def _setup_state_group_id_seq(self):
 | 
						|
        def r(txn):
 | 
						|
            txn.execute("SELECT MAX(id) FROM state_groups")
 | 
						|
            next_id = txn.fetchone()[0] + 1
 | 
						|
            txn.execute("ALTER SEQUENCE state_group_id_seq RESTART WITH %s", (next_id,))
 | 
						|
 | 
						|
        return self.postgres_store.runInteraction("setup_state_group_id_seq", r)
 | 
						|
 | 
						|
 | 
						|
##############################################
 | 
						|
# The following is simply UI stuff
 | 
						|
##############################################
 | 
						|
 | 
						|
 | 
						|
class Progress(object):
    """Used to report progress of the port

    Keeps, for every registered table, the row count it started at, the number
    of rows done so far, the total number of rows and the resulting integer
    percentage.
    """

    def __init__(self):
        # Maps table name -> {"start", "num_done", "total", "perc"}.
        self.tables = {}

        # Whole seconds since the epoch at which this object was created.
        self.start_time = int(time.time())

    @staticmethod
    def _percentage(num_done, total):
        """Return ``num_done`` as an integer percentage of ``total``.

        A table with no rows at all has nothing left to do, so it is reported
        as 100% instead of raising ZeroDivisionError.
        """
        if not total:
            return 100
        return int(num_done * 100 / total)

    def add_table(self, table, cur, size):
        """Start tracking ``table``.

        Args:
            table: name of the table.
            cur: number of rows already done when tracking starts.
            size: total number of rows in the table.
        """
        self.tables[table] = {
            "start": cur,
            "num_done": cur,
            "total": size,
            "perc": self._percentage(cur, size),
        }

    def update(self, table, num_done):
        """Record that ``num_done`` rows of ``table`` are now done.

        Raises:
            KeyError: if ``table`` was never registered with ``add_table``.
        """
        data = self.tables[table]
        data["num_done"] = num_done
        data["perc"] = self._percentage(num_done, data["total"])

    def done(self):
        """Hook invoked once reporting is over; a no-op in the base class."""
        pass
 | 
						|
 | 
						|
 | 
						|
class CursesProgress(Progress):
    """Reports progress to a curses window

    Draws a status line followed by one progress bar per table, ordered by
    completion percentage and then by table name.
    """

    def __init__(self, stdscr):
        """
        Args:
            stdscr: the curses window to draw on.
        """
        self.stdscr = stdscr

        curses.use_default_colors()
        curses.curs_set(0)

        # Colour pair 1 marks unfinished tables, pair 2 finished ones; -1 keeps
        # the terminal's default background.
        curses.init_pair(1, curses.COLOR_RED, -1)
        curses.init_pair(2, curses.COLOR_GREEN, -1)

        # Timestamp of the last redraw, used to throttle render().
        self.last_update = 0

        self.finished = False

        self.total_processed = 0
        self.total_remaining = 0

        super(CursesProgress, self).__init__()

    def update(self, table, num_done):
        """Record progress for ``table``, refresh the totals and redraw."""
        super(CursesProgress, self).update(table, num_done)

        # Recompute the totals from scratch across all tables. Only rows done
        # since tracking started count as "processed".
        self.total_processed = 0
        self.total_remaining = 0
        for data in self.tables.values():
            self.total_processed += data["num_done"] - data["start"]
            self.total_remaining += data["total"] - data["num_done"]

        self.render()

    def render(self, force=False):
        """Redraw the whole window.

        Args:
            force: redraw even if the previous redraw happened less than 0.2
                seconds ago.
        """
        now = time.time()

        if not force and now - self.last_update < 0.2:
            # reactor.callLater(1, self.render)
            return

        self.stdscr.clear()

        rows, cols = self.stdscr.getmaxyx()

        duration = int(now) - int(self.start_time)

        minutes, seconds = divmod(duration, 60)
        duration_str = '%02dm %02ds' % (minutes, seconds)

        if self.finished:
            status = "Time spent: %s (Done!)" % (duration_str,)
        else:

            if self.total_processed > 0:
                # Scale the elapsed time by the ratio of rows left to rows
                # handled so far.
                left = float(self.total_remaining) / self.total_processed

                est_remaining = (int(now) - self.start_time) * left
                est_remaining_str = '%02dm %02ds remaining' % divmod(est_remaining, 60)
            else:
                est_remaining_str = "Unknown"
            status = "Time spent: %s (est. remaining: %s)" % (
                duration_str,
                est_remaining_str,
            )

        self.stdscr.addstr(0, 0, status, curses.A_BOLD)

        # Width of the longest table name, used to right-align the names.
        # Falls back to 0 when no table is registered, so that rendering with
        # nothing to show does not raise ValueError from max().
        max_len = max([len(t) for t in self.tables] or [0])

        left_margin = 5
        middle_space = 1

        items = sorted(self.tables.items(), key=lambda i: (i[1]["perc"], i[0]))

        for i, (table, data) in enumerate(items):
            # Rows 0 and 1 are the status line and a spacer; stop before
            # running off the bottom of the window.
            if i + 2 >= rows:
                break

            perc = data["perc"]

            color = curses.color_pair(2) if perc == 100 else curses.color_pair(1)

            self.stdscr.addstr(
                i + 2, left_margin + max_len - len(table), table, curses.A_BOLD | color
            )

            size = 20

            filled = int(perc * size / 100)
            progress = "[%s%s]" % ("#" * filled, " " * (size - filled))

            self.stdscr.addstr(
                i + 2,
                left_margin + max_len + middle_space,
                "%s %3d%% (%d/%d)" % (progress, perc, data["num_done"], data["total"]),
            )

        if self.finished:
            self.stdscr.addstr(rows - 1, 0, "Press any key to exit...")

        self.stdscr.refresh()
        self.last_update = time.time()

    def done(self):
        """Mark the port as finished, redraw, and block until a key press."""
        self.finished = True
        self.render(True)
        self.stdscr.getch()

    def set_state(self, state):
        """Replace the window contents with a single bold ``state...`` line."""
        self.stdscr.clear()
        self.stdscr.addstr(0, 0, state + "...", curses.A_BOLD)
        self.stdscr.refresh()
 | 
						|
 | 
						|
 | 
						|
class TerminalProgress(Progress):
    """Just prints progress to the terminal
    """

    def update(self, table, num_done):
        """Record the new row count for ``table`` and print one summary line."""
        super(TerminalProgress, self).update(table, num_done)

        entry = self.tables[table]
        summary = "%s: %d%% (%d/%d)" % (
            table,
            entry["perc"],
            entry["num_done"],
            entry["total"],
        )
        print(summary)

    def set_state(self, state):
        """Print ``state`` followed by an ellipsis."""
        message = state + "..."
        print(message)
 | 
						|
 | 
						|
 | 
						|
##############################################
 | 
						|
##############################################
 | 
						|
 | 
						|
 | 
						|
if __name__ == "__main__":
    # Script entry point: parse the command line, load and validate the
    # PostgreSQL configuration, then run the Porter under the Twisted reactor.
    parser = argparse.ArgumentParser(
        description="A script to port an existing synapse SQLite database to"
        " a new PostgreSQL database."
    )
    parser.add_argument("-v", action='store_true')
    parser.add_argument(
        "--sqlite-database",
        required=True,
        help="The snapshot of the SQLite database file. This must not be"
        " currently used by a running synapse server",
    )
    parser.add_argument(
        "--postgres-config",
        type=argparse.FileType('r'),
        required=True,
        help="The database config file for the PostgreSQL database",
    )
    parser.add_argument(
        "--curses", action='store_true', help="display a curses based progress UI"
    )

    parser.add_argument(
        "--batch-size",
        type=int,
        default=1000,
        help="The number of rows to select from the SQLite table each"
        " iteration [default=1000]",
    )

    args = parser.parse_args()

    logging_config = {
        "level": logging.DEBUG if args.v else logging.INFO,
        "format": "%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(message)s",
    }

    # The curses UI owns the terminal, so send log output to a file instead.
    if args.curses:
        logging_config["filename"] = "port-synapse.log"

    logging.basicConfig(**logging_config)

    sqlite_config = {
        "name": "sqlite3",
        "args": {
            "database": args.sqlite_database,
            "cp_min": 1,
            "cp_max": 1,
            "check_same_thread": False,
        },
    }

    hs_config = yaml.safe_load(args.postgres_config)

    # safe_load yields None for an empty file and may yield a non-mapping for
    # other content; report both the same way as a missing section rather
    # than failing with a TypeError on the membership test.
    if not isinstance(hs_config, dict) or "database" not in hs_config:
        sys.stderr.write("The configuration file must have a 'database' section.\n")
        sys.exit(4)

    postgres_config = hs_config["database"]

    # Likewise guard against a 'database' entry that is not a mapping.
    if not isinstance(postgres_config, dict) or "name" not in postgres_config:
        sys.stderr.write("Malformed database config: no 'name'\n")
        sys.exit(2)
    if postgres_config["name"] != "psycopg2":
        sys.stderr.write("Database must use the 'psycopg2' connector.\n")
        sys.exit(3)

    config = HomeServerConfig()
    config.parse_config_dict(hs_config, "", "")

    def start(stdscr=None):
        """Build the progress reporter and the Porter, then run the reactor.

        Args:
            stdscr: curses window supplied by ``curses.wrapper``; when falsy,
                progress is printed to the terminal instead.
        """
        if stdscr:
            progress = CursesProgress(stdscr)
        else:
            progress = TerminalProgress()

        porter = Porter(
            sqlite_config=sqlite_config,
            progress=progress,
            batch_size=args.batch_size,
            hs_config=config,
        )

        reactor.callWhenRunning(porter.run)

        reactor.run()

    if args.curses:
        curses.wrapper(start)
    else:
        start()

    # NOTE(review): end_error_exec_info is a module-level name defined outside
    # this block; presumably it holds sys.exc_info() of a failed run -- confirm.
    if end_error_exec_info:
        exc_type, exc_value, exc_traceback = end_error_exec_info
        traceback.print_exception(exc_type, exc_value, exc_traceback)

        # Report the failure through the exit status as well, so callers do
        # not mistake a failed port for a successful one.
        sys.exit(5)
 |