prepare_database() on db_conn, not plain name, so we can pass in the connection from outside
parent
2faffc52ee
commit
55397f6347
|
@ -39,6 +39,7 @@ import logging
|
|||
import os
|
||||
import re
|
||||
import sys
|
||||
import sqlite3
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
@ -208,7 +209,14 @@ def setup():
|
|||
redirect_root_to_web_client=True,
|
||||
)
|
||||
|
||||
prepare_database(hs.get_db_name())
|
||||
db_name = hs.get_db_name()
|
||||
|
||||
logging.info("Preparing database: %s...", db_name)
|
||||
|
||||
with sqlite3.connect(db_name) as db_conn:
|
||||
prepare_database(db_conn)
|
||||
|
||||
logging.info("Database prepared in %s.", db_name)
|
||||
|
||||
hs.get_db_pool()
|
||||
|
||||
|
|
|
@ -43,7 +43,6 @@ from .keys import KeyStore
|
|||
import json
|
||||
import logging
|
||||
import os
|
||||
import sqlite3
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
@ -370,13 +369,10 @@ def read_schema(schema):
|
|||
return schema_file.read()
|
||||
|
||||
|
||||
def prepare_database(db_name):
|
||||
def prepare_database(db_conn):
|
||||
""" Set up all the dbs. Since all the *.sql have IF NOT EXISTS, so we
|
||||
don't have to worry about overwriting existing content.
|
||||
"""
|
||||
logging.info("Preparing database: %s...", db_name)
|
||||
|
||||
with sqlite3.connect(db_name) as db_conn:
|
||||
c = db_conn.cursor()
|
||||
c.execute("PRAGMA user_version")
|
||||
row = c.fetchone()
|
||||
|
@ -410,4 +406,3 @@ def prepare_database(db_name):
|
|||
|
||||
c.close()
|
||||
|
||||
logging.info("Database prepared in %s.", db_name)
|
||||
|
|
Loading…
Reference in New Issue