# -*- coding: utf-8 -*-
# Copyright 2014, 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging

from synapse.api.errors import StoreError
from synapse.util.logutils import log_function
from synapse.util.logcontext import preserve_context_over_fn, LoggingContext
from synapse.util.caches.dictionary_cache import DictionaryCache
from synapse.util.caches.descriptors import Cache
import synapse.metrics

from util.id_generators import IdGenerator, StreamIdGenerator

from twisted.internet import defer

from collections import namedtuple

import sys
import time
import threading


logger = logging.getLogger(__name__)

sql_logger = logging.getLogger("synapse.storage.SQL")
transaction_logger = logging.getLogger("synapse.storage.txn")
perf_logger = logging.getLogger("synapse.storage.TIME")


metrics = synapse.metrics.get_metrics_for("synapse.storage")

sql_scheduling_timer = metrics.register_distribution("schedule_time")

sql_query_timer = metrics.register_distribution("query_time", labels=["verb"])
sql_txn_timer = metrics.register_distribution("transaction_time", labels=["desc"])


class LoggingTransaction(object):
    """An object that almost-transparently proxies for the 'txn' object
    passed to the constructor. Adds logging and metrics to the .execute()
    and .executemany() methods.

    Attribute reads and writes other than the four slots below are
    forwarded to the wrapped 'txn' object.
    """
    __slots__ = ["txn", "name", "database_engine", "after_callbacks"]

    def __init__(self, txn, name, database_engine, after_callbacks):
        # __setattr__ is overridden to forward to self.txn, so the proxy's
        # own slots have to be populated through object.__setattr__.
        object.__setattr__(self, "txn", txn)
        object.__setattr__(self, "name", name)
        object.__setattr__(self, "database_engine", database_engine)
        object.__setattr__(self, "after_callbacks", after_callbacks)

    def call_after(self, callback, *args):
        """Call the given callback on the main twisted thread after the
        transaction has finished. Used to invalidate the caches on the
        correct thread.

        The (callback, args) pair is appended to the `after_callbacks` list
        supplied to the constructor; this class does not run it itself.
        """
        self.after_callbacks.append((callback, args))

    def __getattr__(self, name):
        # Only reached for names that are not slots/methods of the proxy.
        return getattr(self.txn, name)

    def __setattr__(self, name, value):
        setattr(self.txn, name, value)

    def execute(self, sql, *args):
        """Run `sql` through the wrapped txn's execute(), with logging and
        timing. Returns None."""
        self._do_execute(self.txn.execute, sql, *args)

    def executemany(self, sql, *args):
        """Run `sql` through the wrapped txn's executemany(), with logging
        and timing. Returns None."""
        self._do_execute(self.txn.executemany, sql, *args)

    def _do_execute(self, func, sql, *args):
        """Log, convert the parameter style of, time and run a statement.

        Args:
            func : bound execute/executemany of the wrapped txn
            sql : SQL string; passed through
                database_engine.convert_param_style() before running
            *args : positional args forwarded to `func`; args[0] (if any) is
                logged as the statement's values

        Returns:
            Whatever `func` returns. Any exception from `func` is logged and
            re-raised.
        """
        # TODO(paul): Maybe use 'info' and 'debug' for values?
        sql_logger.debug("[SQL] {%s} %s", self.name, sql)

        sql = self.database_engine.convert_param_style(sql)

        if args:
            try:
                sql_logger.debug(
                    "[SQL values] {%s} %r",
                    self.name, args[0]
                )
            except Exception:
                # Don't let logging failures stop SQL from working. Only
                # Exception is swallowed so that KeyboardInterrupt and
                # SystemExit still propagate.
                pass

        start = time.time() * 1000

        try:
            return func(
                sql, *args
            )
        except Exception as e:
            logger.debug("[SQL FAIL] {%s} %s", self.name, e)
            raise
        finally:
            msecs = (time.time() * 1000) - start
            sql_logger.debug("[SQL time] {%s} %f", self.name, msecs)
            # The first word of the statement (e.g. SELECT) is the label.
            sql_query_timer.inc_by(msecs, sql.split()[0])


class PerformanceCounters(object):
    """Accumulates (call count, cumulative duration) per key and reports the
    keys with the largest share of time over an interval."""

    def __init__(self):
        # key -> (count, cumulative time) since creation
        self.current_counters = {}
        # Snapshot of current_counters taken at the last interval() call
        self.previous_counters = {}

    def update(self, key, start_time, end_time=None):
        """Record one call for `key` lasting from start_time to end_time.

        Args:
            key : counter name
            start_time : start timestamp, in milliseconds
            end_time : end timestamp in milliseconds; defaults to now

        Returns:
            The end_time that was used.
        """
        if end_time is None:
            end_time = time.time() * 1000
        duration = end_time - start_time
        count, cum_time = self.current_counters.get(key, (0, 0))
        count += 1
        cum_time += duration
        self.current_counters[key] = (count, cum_time)
        return end_time

    def interval(self, interval_duration, limit=3):
        """Summarise activity since the previous call to interval().

        Args:
            interval_duration : length of the interval, in the same units as
                the recorded durations
            limit : how many of the top counters to include

        Returns:
            str: comma separated "name(count): percent%" entries, ordered by
            descending share of the interval.
        """
        counters = []
        for name, (count, cum_time) in self.current_counters.items():
            prev_count, prev_time = self.previous_counters.get(name, (0, 0))
            counters.append((
                (cum_time - prev_time) / interval_duration,
                count - prev_count,
                name
            ))

        self.previous_counters = dict(self.current_counters)

        counters.sort(reverse=True)

        top_n_counters = ", ".join(
            "%s(%d): %.3f%%" % (name, count, 100 * ratio)
            for ratio, count, name in counters[:limit]
        )

        return top_n_counters


class SQLBaseStore(object):
    """Base class for stores: runs functions against the homeserver's
    database pool inside logged, timed and retried transactions, and offers
    "simple" single-table helper queries."""

    _TXN_ID = 0

    def __init__(self, hs):
        self.hs = hs
        self._db_pool = hs.get_db_pool()
        self._clock = hs.get_clock()

        self._previous_txn_total_time = 0
        self._current_txn_total_time = 0
        self._previous_loop_ts = 0

        # TODO(paul): These can eventually be removed once the metrics code
        #   is running in mainline, and we have some nice monitoring frontends
        #   to watch it
        self._txn_perf_counters = PerformanceCounters()
        self._get_event_counters = PerformanceCounters()

        self._get_event_cache = Cache("*getEvent*", keylen=3, lru=True,
                                      max_entries=hs.config.event_cache_size)

        self._state_group_cache = DictionaryCache("*stateGroupCache*", 2000)

        self._event_fetch_lock = threading.Condition()
        self._event_fetch_list = []
        self._event_fetch_ongoing = 0

        self._pending_ds = []

        self.database_engine = hs.database_engine

        self._stream_id_gen = StreamIdGenerator("events", "stream_ordering")
        self._transaction_id_gen = IdGenerator("sent_transactions", "id", self)
        self._state_groups_id_gen = IdGenerator("state_groups", "id", self)
        self._access_tokens_id_gen = IdGenerator("access_tokens", "id", self)
        self._pushers_id_gen = IdGenerator("pushers", "id", self)
        self._push_rule_id_gen = IdGenerator("push_rules", "id", self)
        self._push_rules_enable_id_gen = IdGenerator("push_rules_enable", "id", self)
        self._receipts_id_gen = StreamIdGenerator("receipts_linearized", "stream_id")

    def start_profiling(self):
        """Start a looping call (every 10000ms) that logs the proportion of
        wall-clock time spent in transactions, plus the top three
        transaction and get-event counters."""
        self._previous_loop_ts = self._clock.time_msec()

        def loop():
            curr = self._current_txn_total_time
            prev = self._previous_txn_total_time
            self._previous_txn_total_time = curr

            time_now = self._clock.time_msec()
            time_then = self._previous_loop_ts
            self._previous_loop_ts = time_now

            ratio = (curr - prev)/(time_now - time_then)

            top_three_counters = self._txn_perf_counters.interval(
                time_now - time_then, limit=3
            )

            top_3_event_counters = self._get_event_counters.interval(
                time_now - time_then, limit=3
            )

            perf_logger.info(
                "Total database time: %.3f%% {%s} {%s}",
                ratio * 100, top_three_counters, top_3_event_counters
            )

        self._clock.looping_call(loop, 10000)

    def _new_transaction(self, conn, desc, after_callbacks, func, *args, **kwargs):
        """Run func(txn, *args, **kwargs) in a transaction on `conn`.

        A cursor is opened, wrapped in a LoggingTransaction and passed to
        `func`; on success the connection is committed. OperationalErrors,
        and DatabaseErrors that the engine reports as deadlocks, cause a
        rollback and a retry, up to N (5) retries; anything else is
        re-raised.

        Args:
            conn : database connection
            desc : description used for logging and metrics
            after_callbacks : list handed to the LoggingTransaction
            func : callable taking the transaction as its first argument

        Returns:
            The return value of `func`.
        """
        start = time.time() * 1000
        txn_id = self._TXN_ID

        # We don't really need these to be unique, so lets stop it from
        # growing really large.
        self._TXN_ID = (self._TXN_ID + 1) % (sys.maxint - 1)

        name = "%s-%x" % (desc, txn_id, )

        transaction_logger.debug("[TXN START] {%s}", name)

        try:
            i = 0
            N = 5
            while True:
                try:
                    txn = conn.cursor()
                    txn = LoggingTransaction(
                        txn, name, self.database_engine, after_callbacks
                    )
                    r = func(txn, *args, **kwargs)
                    conn.commit()
                    return r
                except self.database_engine.module.OperationalError as e:
                    # This can happen if the database disappears mid
                    # transaction.
                    logger.warn(
                        "[TXN OPERROR] {%s} %s %d/%d",
                        name, e, i, N
                    )
                    if i < N:
                        i += 1
                        try:
                            conn.rollback()
                        except self.database_engine.module.Error as e1:
                            logger.warn(
                                "[TXN EROLL] {%s} %s",
                                name, e1,
                            )
                        continue
                    raise
                except self.database_engine.module.DatabaseError as e:
                    if self.database_engine.is_deadlock(e):
                        logger.warn("[TXN DEADLOCK] {%s} %d/%d", name, i, N)
                        if i < N:
                            i += 1
                            try:
                                conn.rollback()
                            except self.database_engine.module.Error as e1:
                                logger.warn(
                                    "[TXN EROLL] {%s} %s",
                                    name, e1,
                                )
                            continue
                    raise
        except Exception as e:
            logger.debug("[TXN FAIL] {%s} %s", name, e)
            raise
        finally:
            # Timing covers every attempt, including retries.
            end = time.time() * 1000
            duration = end - start

            transaction_logger.debug("[TXN END] {%s} %f", name, duration)

            self._current_txn_total_time += duration
            self._txn_perf_counters.update(desc, start, end)
            sql_txn_timer.inc_by(duration, desc)

    @defer.inlineCallbacks
    def runInteraction(self, desc, func, *args, **kwargs):
        """Run func(txn, *args, **kwargs) in a transaction on a connection
        obtained via the underlying db_pool's .runWithConnection().

        Callbacks registered through LoggingTransaction.call_after() are
        invoked, in order, once the transaction has returned.

        Args:
            desc : description used for logging and metrics
            func : callable taking a LoggingTransaction as first argument

        Returns:
            Deferred: resolves to the return value of `func`.
        """
        current_context = LoggingContext.current_context()

        start_time = time.time() * 1000

        after_callbacks = []

        def inner_func(conn, *args, **kwargs):
            with LoggingContext("runInteraction") as context:
                sql_scheduling_timer.inc_by(time.time() * 1000 - start_time)

                if self.database_engine.is_connection_closed(conn):
                    logger.debug("Reconnecting closed database connection")
                    conn.reconnect()

                current_context.copy_to(context)
                return self._new_transaction(
                    conn, desc, after_callbacks, func, *args, **kwargs
                )

        result = yield preserve_context_over_fn(
            self._db_pool.runWithConnection,
            inner_func, *args, **kwargs
        )

        for after_callback, after_args in after_callbacks:
            after_callback(*after_args)
        defer.returnValue(result)

    @defer.inlineCallbacks
    def runWithConnection(self, func, *args, **kwargs):
        """Wraps the .runWithConnection() method on the underlying db_pool.

        Unlike runInteraction, `func` receives the raw connection and no
        transaction handling (cursor, commit, retry) is done for it here.

        Args:
            func : callable taking the connection as its first argument

        Returns:
            Deferred: resolves to the return value of `func`.
        """
        current_context = LoggingContext.current_context()

        start_time = time.time() * 1000

        def inner_func(conn, *args, **kwargs):
            with LoggingContext("runWithConnection") as context:
                sql_scheduling_timer.inc_by(time.time() * 1000 - start_time)

                if self.database_engine.is_connection_closed(conn):
                    logger.debug("Reconnecting closed database connection")
                    conn.reconnect()

                current_context.copy_to(context)

                return func(conn, *args, **kwargs)

        result = yield preserve_context_over_fn(
            self._db_pool.runWithConnection,
            inner_func, *args, **kwargs
        )

        defer.returnValue(result)

    def cursor_to_dict(self, cursor):
        """Converts a SQL cursor into an list of dicts.

        Args:
            cursor : The DBAPI cursor which has executed a query.
        Returns:
            A list of dicts where the key is the column header.
        """
        col_headers = list(column[0] for column in cursor.description)
        results = list(
            dict(zip(col_headers, row)) for row in cursor.fetchall()
        )
        return results

    def _execute(self, desc, decoder, query, *args):
        """Runs a single query for a result set.

        Args:
            desc - Description used for logging and metrics.
            decoder - The function which can resolve the cursor results to
                something meaningful. If falsy, txn.fetchall() is returned.
            query - The query string to execute
            *args - Query args.
        Returns:
            A deferred resolving to the result of decoder(txn)
        """
        def interaction(txn):
            txn.execute(query, args)
            if decoder:
                return decoder(txn)
            else:
                return txn.fetchall()

        return self.runInteraction(desc, interaction)

    def _execute_and_decode(self, desc, query, *args):
        """Run `query` and return (via a deferred) its rows as a list of
        dicts keyed by column name."""
        return self._execute(desc, self.cursor_to_dict, query, *args)

    # "Simple" SQL API methods that operate on a single table with no JOINs,
    # no complex WHERE clauses, just a dict of values for columns.

    @defer.inlineCallbacks
    def _simple_insert(self, table, values, or_ignore=False,
                       desc="_simple_insert"):
        """Executes an INSERT query on the named table.

        Args:
            table : string giving the table name
            values : dict of new column names and values for them
            or_ignore : if true, an IntegrityError from the database is
                swallowed instead of raised
            desc : description used for logging and metrics
        """
        try:
            yield self.runInteraction(
                desc,
                self._simple_insert_txn, table, values,
            )
        except self.database_engine.module.IntegrityError:
            # We have to do or_ignore flag at this layer, since we can't reuse
            # a cursor after we receive an error from the db.
            if not or_ignore:
                raise

    @log_function
    def _simple_insert_txn(self, txn, table, values):
        """Insert one row, given as a dict of column -> value, into `table`
        using the transaction `txn`."""
        keys, vals = zip(*values.items())

        sql = "INSERT INTO %s (%s) VALUES(%s)" % (
            table,
            ", ".join(k for k in keys),
            ", ".join("?" for _ in keys)
        )

        txn.execute(sql, vals)

    def _simple_insert_many_txn(self, txn, table, values):
        """Insert several rows into `table` with a single executemany().

        Args:
            txn : Transaction object
            table : string giving the table name
            values : list of dicts of column names to values; every dict
                must have the same keys

        Raises:
            RuntimeError: if the dicts do not all share the same keys.
        """
        if not values:
            return

        # This is a *slight* abomination to get a list of tuples of key names
        # and a list of tuples of value names.
        #
        # i.e. [{"a": 1, "b": 2}, {"c": 3, "d": 4}]
        #         => [("a", "b",), ("c", "d",)] and [(1, 2,), (3, 4,)]
        #
        # The sort is to ensure that we don't rely on dictionary iteration
        # order.
        keys, vals = zip(*[
            zip(
                *(sorted(i.items(), key=lambda kv: kv[0]))
            )
            for i in values
            if i
        ])

        for k in keys:
            if k != keys[0]:
                raise RuntimeError(
                    "All items must have the same keys"
                )

        sql = "INSERT INTO %s (%s) VALUES(%s)" % (
            table,
            ", ".join(k for k in keys[0]),
            ", ".join("?" for _ in keys[0])
        )

        txn.executemany(sql, vals)

    def _simple_upsert(self, table, keyvalues, values,
                       insertion_values={}, desc="_simple_upsert", lock=True):
        """Update the row matching `keyvalues`, or insert it if no row was
        updated.

        Args:
            table (str): The table to upsert into
            keyvalues (dict): The unique key columns and their new values
            values (dict): The nonunique columns and their new values
            insertion_values (dict): key/values to use when inserting
            desc (str): description used for logging and metrics
            lock (bool): whether to lock the table first
        Returns:
            A deferred
        """
        return self.runInteraction(
            desc,
            self._simple_upsert_txn, table, keyvalues, values, insertion_values,
            lock
        )

    def _simple_upsert_txn(self, txn, table, keyvalues, values, insertion_values={},
                           lock=True):
        """Transaction body of _simple_upsert: try an UPDATE and, if it
        touched no rows, INSERT keyvalues + values + insertion_values.

        Note that insertion_values is only read here, never mutated, so the
        shared default dict is safe.
        """
        # We need to lock the table :(, unless we're *really* careful
        if lock:
            self.database_engine.lock_table(txn, table)

        # Try to update
        sql = "UPDATE %s SET %s WHERE %s" % (
            table,
            ", ".join("%s = ?" % (k,) for k in values),
            " AND ".join("%s = ?" % (k,) for k in keyvalues)
        )
        sqlargs = values.values() + keyvalues.values()
        logger.debug(
            "[SQL] %s Args=%s",
            sql, sqlargs,
        )

        txn.execute(sql, sqlargs)
        if txn.rowcount == 0:
            # We didn't update any rows so insert a new one
            allvalues = {}
            allvalues.update(keyvalues)
            allvalues.update(values)
            allvalues.update(insertion_values)

            sql = "INSERT INTO %s (%s) VALUES (%s)" % (
                table,
                ", ".join(k for k in allvalues),
                ", ".join("?" for _ in allvalues)
            )
            insert_args = allvalues.values()
            # Log the arguments that are actually sent with the INSERT.
            logger.debug(
                "[SQL] %s Args=%s",
                sql, insert_args,
            )
            txn.execute(sql, insert_args)

    def _simple_select_one(self, table, keyvalues, retcols,
                           allow_none=False, desc="_simple_select_one"):
        """Executes a SELECT query on the named table, which is expected to
        return a single row, returning the requested columns from it as a
        dict.

        Args:
            table : string giving the table name
            keyvalues : dict of column names and values to select the row with
            retcols : list of strings giving the names of the columns to return
            allow_none : If true, return None instead of failing if the SELECT
              statement returns no rows
            desc : description used for logging and metrics
        """
        return self.runInteraction(
            desc,
            self._simple_select_one_txn,
            table, keyvalues, retcols, allow_none,
        )

    def _simple_select_one_onecol(self, table, keyvalues, retcol,
                                  allow_none=False,
                                  desc="_simple_select_one_onecol"):
        """Executes a SELECT query on the named table, which is expected to
        return a single row, returning a single column from it.

        Args:
            table : string giving the table name
            keyvalues : dict of column names and values to select the row with
            retcol : string giving the name of the column to return
            allow_none : If true, return None instead of failing if the SELECT
              statement returns no rows
            desc : description used for logging and metrics
        """
        return self.runInteraction(
            desc,
            self._simple_select_one_onecol_txn,
            table, keyvalues, retcol, allow_none=allow_none,
        )

    def _simple_select_one_onecol_txn(self, txn, table, keyvalues, retcol,
                                      allow_none=False):
        """Return the value of `retcol` from the first matching row.

        Raises:
            StoreError: 404 if no row matched and allow_none is false.
        """
        ret = self._simple_select_onecol_txn(
            txn,
            table=table,
            keyvalues=keyvalues,
            retcol=retcol,
        )

        if ret:
            return ret[0]
        else:
            if allow_none:
                return None
            else:
                raise StoreError(404, "No row found")

    def _simple_select_onecol_txn(self, txn, table, keyvalues, retcol):
        """Return a list of the values of `retcol` for every row of `table`
        matching `keyvalues`."""
        sql = (
            "SELECT %(retcol)s FROM %(table)s WHERE %(where)s"
        ) % {
            "retcol": retcol,
            "table": table,
            "where": " AND ".join("%s = ?" % k for k in keyvalues.keys()),
        }

        txn.execute(sql, keyvalues.values())

        return [r[0] for r in txn.fetchall()]

    def _simple_select_onecol(self, table, keyvalues, retcol,
                              desc="_simple_select_onecol"):
        """Executes a SELECT query on the named table, which returns a list
        comprising of the values of the named column from the selected rows.

        Args:
            table (str): table name
            keyvalues (dict): column names and values to select the rows with
            retcol (str): column whose value we wish to retrieve.
            desc (str): description used for logging and metrics

        Returns:
            Deferred: Results in a list
        """
        return self.runInteraction(
            desc,
            self._simple_select_onecol_txn,
            table, keyvalues, retcol
        )

    def _simple_select_list(self, table, keyvalues, retcols,
                            desc="_simple_select_list"):
        """Executes a SELECT query on the named table, which may return zero or
        more rows, returning the result as a list of dicts.

        Args:
            table : string giving the table name
            keyvalues : dict of column names and values to select the rows with,
            or None to not apply a WHERE clause.
            retcols : list of strings giving the names of the columns to return
            desc : description used for logging and metrics
        """
        return self.runInteraction(
            desc,
            self._simple_select_list_txn,
            table, keyvalues, retcols
        )

    def _simple_select_list_txn(self, txn, table, keyvalues, retcols):
        """Executes a SELECT query on the named table, which may return zero or
        more rows, returning the result as a list of dicts.

        Args:
            txn : Transaction object
            table : string giving the table name
            keyvalues : dict of column names and values to select the rows
                with; if empty or None no WHERE clause is applied
            retcols : list of strings giving the names of the columns to return
        """
        if keyvalues:
            sql = "SELECT %s FROM %s WHERE %s" % (
                ", ".join(retcols),
                table,
                " AND ".join("%s = ?" % (k, ) for k in keyvalues)
            )
            txn.execute(sql, keyvalues.values())
        else:
            sql = "SELECT %s FROM %s" % (
                ", ".join(retcols),
                table
            )
            txn.execute(sql)

        return self.cursor_to_dict(txn)

    def _simple_update_one(self, table, keyvalues, updatevalues,
                           desc="_simple_update_one"):
        """Executes an UPDATE query on the named table, setting new values for
        columns in a row matching the key values.

        Args:
            table : string giving the table name
            keyvalues : dict of column names and values to select the row with
            updatevalues : dict giving column names and values to update
            desc : description used for logging and metrics

        This can be used to implement compare-and-set by putting the update
        column in the 'keyvalues' dict as well.

        The returned deferred fails with a StoreError if no row, or more
        than one row, matched (see _simple_update_one_txn).
        """
        return self.runInteraction(
            desc,
            self._simple_update_one_txn,
            table, keyvalues, updatevalues,
        )

    def _simple_update_one_txn(self, txn, table, keyvalues, updatevalues):
        """Update exactly one row of `table`.

        Raises:
            StoreError: 404 if no row matched, 500 if more than one did.
        """
        update_sql = "UPDATE %s SET %s WHERE %s" % (
            table,
            ", ".join("%s = ?" % (k,) for k in updatevalues),
            " AND ".join("%s = ?" % (k,) for k in keyvalues)
        )

        txn.execute(
            update_sql,
            updatevalues.values() + keyvalues.values()
        )

        if txn.rowcount == 0:
            raise StoreError(404, "No row found")
        if txn.rowcount > 1:
            raise StoreError(500, "More than one row matched")

    def _simple_select_one_txn(self, txn, table, keyvalues, retcols,
                               allow_none=False):
        """Select `retcols` from the row of `table` matching `keyvalues`.

        Returns:
            dict of retcol name -> value, or None if no row matched and
            allow_none is true.

        Raises:
            StoreError: 404 if no row matched and allow_none is false, 500
            if the cursor reports more than one row.
        """
        select_sql = "SELECT %s FROM %s WHERE %s" % (
            ", ".join(retcols),
            table,
            " AND ".join("%s = ?" % (k,) for k in keyvalues)
        )

        txn.execute(select_sql, keyvalues.values())

        row = txn.fetchone()
        if not row:
            if allow_none:
                return None
            raise StoreError(404, "No row found")
        if txn.rowcount > 1:
            raise StoreError(500, "More than one row matched")

        return dict(zip(retcols, row))

    def _simple_selectupdate_one(self, table, keyvalues, updatevalues=None,
                                 retcols=None, allow_none=False,
                                 desc="_simple_selectupdate_one"):
        """ Combined SELECT then UPDATE.

        If `retcols` is given, those columns are selected *before* the
        update and returned as a dict. If `updatevalues` is given, the
        matching row is then updated. Both happen within one transaction.
        """
        def func(txn):
            ret = None
            if retcols:
                ret = self._simple_select_one_txn(
                    txn,
                    table=table,
                    keyvalues=keyvalues,
                    retcols=retcols,
                    allow_none=allow_none,
                )

            if updatevalues:
                self._simple_update_one_txn(
                    txn,
                    table=table,
                    keyvalues=keyvalues,
                    updatevalues=updatevalues,
                )

                # if txn.rowcount == 0:
                #     raise StoreError(404, "No row found")
                if txn.rowcount > 1:
                    raise StoreError(500, "More than one row matched")

            return ret
        return self.runInteraction(desc, func)

    def _simple_delete_one(self, table, keyvalues, desc="_simple_delete_one"):
        """Executes a DELETE query on the named table, expecting to delete a
        single row.

        Args:
            table : string giving the table name
            keyvalues : dict of column names and values to select the row with
            desc : description used for logging and metrics

        The returned deferred fails with a StoreError (404 / 500) if no row,
        or more than one row, was deleted.
        """
        sql = "DELETE FROM %s WHERE %s" % (
            table,
            " AND ".join("%s = ?" % (k, ) for k in keyvalues)
        )

        def func(txn):
            txn.execute(sql, keyvalues.values())
            if txn.rowcount == 0:
                raise StoreError(404, "No row found")
            if txn.rowcount > 1:
                raise StoreError(500, "more than one row matched")
        return self.runInteraction(desc, func)

    def _simple_delete(self, table, keyvalues, desc="_simple_delete"):
        """Executes a DELETE query on the named table.

        Args:
            table : string giving the table name
            keyvalues : dict of column names and values to select the row with
            desc : description used for logging and metrics
        """
        # table and keyvalues must be forwarded: _simple_delete_txn takes
        # them as required positional arguments after txn.
        return self.runInteraction(
            desc, self._simple_delete_txn, table, keyvalues
        )

    def _simple_delete_txn(self, txn, table, keyvalues):
        """Delete every row of `table` matching `keyvalues` using `txn`."""
        sql = "DELETE FROM %s WHERE %s" % (
            table,
            " AND ".join("%s = ?" % (k, ) for k in keyvalues)
        )

        return txn.execute(sql, keyvalues.values())

    def _simple_max_id(self, table):
        """Executes a SELECT query on the named table, expecting to return the
        max value for the column "id".

        Args:
            table : string giving the table name

        Returns:
            Deferred: the maximum id, or 0 if MAX(id) is NULL.
        """
        sql = "SELECT MAX(id) AS id FROM %s" % table

        def func(txn):
            txn.execute(sql)
            max_id = self.cursor_to_dict(txn)[0]["id"]
            if max_id is None:
                return 0
            return max_id

        return self.runInteraction("_simple_max_id", func)

    def get_next_stream_id(self):
        """Return the current value of self._next_stream_id and increment
        it, under self._next_stream_id_lock."""
        # NOTE(review): neither _next_stream_id_lock nor _next_stream_id is
        # set in __init__ above; presumably a subclass provides them -
        # confirm before relying on this method.
        with self._next_stream_id_lock:
            i = self._next_stream_id
            self._next_stream_id += 1
            return i


class _RollbackButIsFineException(Exception):
    """ This exception is used to rollback a transaction without implying
    something went wrong.
    """
    pass


class Table(object):
    """ A base class used to store information about a particular table.
    """

    table_name = None
    """ str: The name of the table """

    fields = None
    """ list: The field names """

    EntryType = None
    """ Type: A tuple type used to decode the results """

    _select_where_clause = "SELECT %s FROM %s WHERE %s"
    _select_clause = "SELECT %s FROM %s"
    _insert_clause = "REPLACE INTO %s (%s) VALUES (%s)"

    @classmethod
    def select_statement(cls, where_clause=None):
        """
        Args:
            where_clause (str): The WHERE clause to use.

        Returns:
            str: An SQL statement to select rows from the table with the given
            WHERE clause, or without a WHERE clause if none is given.
        """
        if where_clause:
            return cls._select_where_clause % (
                ", ".join(cls.fields),
                cls.table_name,
                where_clause
            )
        else:
            return cls._select_clause % (
                ", ".join(cls.fields),
                cls.table_name,
            )

    @classmethod
    def insert_statement(cls):
        """
        Returns:
            str: A "REPLACE INTO" statement for the table with one "?"
            placeholder per field.
        """
        return cls._insert_clause % (
            cls.table_name,
            ", ".join(cls.fields),
            ", ".join(["?"] * len(cls.fields)),
        )

    @classmethod
    def decode_single_result(cls, results):
        """ Given an iterable of tuples, return a single instance of
            `EntryType` or None if the iterable is empty
        Args:
            results (list): The results list to convert to `EntryType`
        Returns:
            EntryType: An instance of `EntryType` built from the first row
        """
        results = list(results)
        if results:
            return cls.EntryType(*results[0])
        else:
            return None

    @classmethod
    def decode_results(cls, results):
        """ Given an iterable of tuples, return a list of `EntryType`
        Args:
            results (list): The results list to convert to `EntryType`

        Returns:
            list: A list of `EntryType`
        """
        return [cls.EntryType(*row) for row in results]

    @classmethod
    def get_fields_string(cls, prefix=None):
        """
        Args:
            prefix (str): Optional table alias to qualify each field with.

        Returns:
            str: The field names joined with ", ", each written as
            "prefix.field" when a prefix is given.
        """
        if prefix:
            to_join = ("%s.%s" % (prefix, f) for f in cls.fields)
        else:
            to_join = cls.fields

        return ", ".join(to_join)


class JoinHelper(object):
    """ Used to help do joins on tables by looking at the tables' fields and
    creating a list of unique fields to use with SELECTs and a namedtuple
    to dump the results into.

    Attributes:
        tables (list): List of `Table` classes
        EntryType (type)
    """

    def __init__(self, *tables):
        self.tables = tables

        # Collect field names in table order, keeping the first occurrence
        # of any name shared between tables.
        res = []
        for table in self.tables:
            res += [f for f in table.fields if f not in res]

        self.EntryType = namedtuple("JoinHelperEntry", res)

    def get_fields(self, **prefixes):
        """Get a string representing a list of fields for use in SELECT
        statements with the given prefixes applied to each.

        Each field is prefixed using the first table (in constructor order)
        that declares it; `prefixes` is keyed by the table class's name.

        For example::

            JoinHelper(PdusTable, StateTable).get_fields(
                PdusTable="pdus",
                StateTable="state"
            )
        """
        res = []
        for field in self.EntryType._fields:
            for table in self.tables:
                if field in table.fields:
                    res.append("%s.%s" % (prefixes[table.__name__], field))
                    break

        return ", ".join(res)

    def decode_results(self, rows):
        """Convert each row tuple into an instance of `EntryType`."""
        return [self.EntryType(*row) for row in rows]