Make database selection configurable
parent
0d0610870d
commit
455579ca90
|
@ -61,6 +61,7 @@ import resource
|
|||
import subprocess
|
||||
import sqlite3
|
||||
import syweb
|
||||
import yaml
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
@ -108,15 +109,15 @@ class SynapseHomeServer(HomeServer):
|
|||
return None
|
||||
|
||||
def build_db_pool(self):
|
||||
name = self.db_config.pop("name", None)
|
||||
if name == "MySQLdb":
|
||||
return adbapi.ConnectionPool(
|
||||
"sqlite3", self.get_db_name(),
|
||||
check_same_thread=False,
|
||||
cp_min=1,
|
||||
cp_max=1,
|
||||
cp_openfun=prepare_database, # Prepare the database for each conn
|
||||
# so that :memory: sqlite works
|
||||
name,
|
||||
**self.db_config
|
||||
)
|
||||
|
||||
raise RuntimeError("Unsupported database type")
|
||||
|
||||
def create_resource_tree(self, redirect_root_to_web_client):
|
||||
"""Create the resource tree for this Home Server.
|
||||
|
||||
|
@ -357,11 +358,29 @@ def setup(config_options):
|
|||
|
||||
tls_context_factory = context_factory.ServerContextFactory(config)
|
||||
|
||||
if config.database_config:
|
||||
with open(config.database_config, 'r') as f:
|
||||
db_config = yaml.safe_load(f)
|
||||
|
||||
name = db_config.get("name", None)
|
||||
if name == "MySQLdb":
|
||||
db_config.update({
|
||||
"sql_mode": "TRADITIONAL",
|
||||
"charset": "utf8",
|
||||
"use_unicode": True,
|
||||
})
|
||||
else:
|
||||
db_config = {
|
||||
"name": "sqlite3",
|
||||
"database": config.database_path,
|
||||
}
|
||||
|
||||
hs = SynapseHomeServer(
|
||||
config.server_name,
|
||||
domain_with_port=domain_with_port,
|
||||
upload_dir=os.path.abspath("uploads"),
|
||||
db_name=config.database_path,
|
||||
db_config=db_config,
|
||||
tls_context_factory=tls_context_factory,
|
||||
config=config,
|
||||
content_addr=config.content_addr,
|
||||
|
@ -377,8 +396,11 @@ def setup(config_options):
|
|||
logger.info("Preparing database: %s...", db_name)
|
||||
|
||||
try:
|
||||
with sqlite3.connect(db_name) as db_conn:
|
||||
prepare_sqlite3_database(db_conn)
|
||||
# with sqlite3.connect(db_name) as db_conn:
|
||||
# prepare_sqlite3_database(db_conn)
|
||||
# prepare_database(db_conn)
|
||||
import MySQLdb
|
||||
db_conn = MySQLdb.connect(**db_config)
|
||||
prepare_database(db_conn)
|
||||
except UpgradeDatabaseException:
|
||||
sys.stderr.write(
|
||||
|
|
|
@ -26,6 +26,11 @@ class DatabaseConfig(Config):
|
|||
self.database_path = self.abspath(args.database_path)
|
||||
self.event_cache_size = self.parse_size(args.event_cache_size)
|
||||
|
||||
if args.database_config:
|
||||
self.database_config = self.abspath(args.database_config)
|
||||
else:
|
||||
self.database_config = None
|
||||
|
||||
@classmethod
|
||||
def add_arguments(cls, parser):
|
||||
super(DatabaseConfig, cls).add_arguments(parser)
|
||||
|
@ -38,6 +43,10 @@ class DatabaseConfig(Config):
|
|||
"--event-cache-size", default="100K",
|
||||
help="Number of events to cache in memory."
|
||||
)
|
||||
db_group.add_argument(
|
||||
"--database-config", default=None,
|
||||
help="Location of the database configuration file."
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def generate_config(cls, args, config_dir_path):
|
||||
|
|
Loading…
Reference in New Issue