From a33a5abc4cd0a73fdf851850da21d749d842ae32 Mon Sep 17 00:00:00 2001 From: Amber Brown Date: Fri, 5 Apr 2019 00:21:16 +1100 Subject: [PATCH] Clean up the database pagination code (#5007) * rewrite & simplify * changelog * cleanup potential sql injection --- changelog.d/5007.misc | 1 + synapse/storage/__init__.py | 20 ++++--- synapse/storage/_base.py | 112 ++++++++++++++++-------------------- 3 files changed, 65 insertions(+), 68 deletions(-) create mode 100644 changelog.d/5007.misc diff --git a/changelog.d/5007.misc b/changelog.d/5007.misc new file mode 100644 index 0000000000..05b6ce2c26 --- /dev/null +++ b/changelog.d/5007.misc @@ -0,0 +1 @@ +Refactor synapse.storage._base._simple_select_list_paginate. \ No newline at end of file diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py index e9aa2fc9dd..c432041b4e 100644 --- a/synapse/storage/__init__.py +++ b/synapse/storage/__init__.py @@ -18,6 +18,8 @@ import calendar import logging import time +from twisted.internet import defer + from synapse.api.constants import PresenceState from synapse.storage.devices import DeviceStore from synapse.storage.user_erasure_store import UserErasureStore @@ -453,6 +455,7 @@ class DataStore( desc="get_users", ) + @defer.inlineCallbacks def get_users_paginate(self, order, start, limit): """Function to reterive a paginated list of users from users list. This will return a json object, which contains @@ -465,16 +468,19 @@ class DataStore( Returns: defer.Deferred: resolves to json object {list[dict[str, Any]], count} """ - is_guest = 0 - i_start = (int)(start) - i_limit = (int)(limit) - return self.get_user_list_paginate( + users = yield self.runInteraction( + "get_users_paginate", + self._simple_select_list_paginate_txn, table="users", - keyvalues={"is_guest": is_guest}, - pagevalues=[order, i_limit, i_start], + keyvalues={"is_guest": False}, + orderby=order, + start=start, + limit=limit, retcols=["name", "password_hash", "is_guest", "admin"], - desc="get_users_paginate", ) + count = yield self.runInteraction("get_users_paginate", self.get_user_count_txn) + retval = {"users": users, "total": count} + defer.returnValue(retval) def search_users(self, term): """Function to search users list for one or more users with diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index 131820628a..983ce026e1 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -595,7 +595,7 @@ class SQLBaseStore(object): Args: table (str): The table to upsert into - keyvalues (dict): The unique key tables and their new values + keyvalues (dict): The unique key columns and their new values values (dict): The nonunique columns and their new values insertion_values (dict): additional key/values to use only when inserting @@ -627,7 +627,7 @@ class SQLBaseStore(object): # presumably we raced with another transaction: let's retry. logger.warn( - "%s when upserting into %s; retrying: %s", e.__name__, table, e + "IntegrityError when upserting into %s; retrying: %s", table, e ) def _simple_upsert_txn( @@ -1398,21 +1398,31 @@ class SQLBaseStore(object): return 0 def _simple_select_list_paginate( - self, table, keyvalues, pagevalues, retcols, desc="_simple_select_list_paginate" + self, + table, + keyvalues, + orderby, + start, + limit, + retcols, + order_direction="ASC", + desc="_simple_select_list_paginate", ): - """Executes a SELECT query on the named table with start and limit, + """ + Executes a SELECT query on the named table with start and limit, of row numbers, which may return zero or number of rows from start to limit, returning the result as a list of dicts. Args: table (str): the table name - keyvalues (dict[str, Any] | None): + keyvalues (dict[str, T] | None): column names and values to select the rows with, or None to not apply a WHERE clause. + orderby (str): Column to order the results by. + start (int): Index to begin the query at. + limit (int): Number of results to return. retcols (iterable[str]): the names of the columns to return - order (str): order the select by this column - start (int): start number to begin the query from - limit (int): number of rows to reterive + order_direction (str): Whether the results should be ordered "ASC" or "DESC". Returns: defer.Deferred: resolves to list[dict[str, Any]] """ @@ -1421,15 +1431,27 @@ class SQLBaseStore(object): self._simple_select_list_paginate_txn, table, keyvalues, - pagevalues, + orderby, + start, + limit, retcols, + order_direction=order_direction, ) @classmethod def _simple_select_list_paginate_txn( - cls, txn, table, keyvalues, pagevalues, retcols + cls, + txn, + table, + keyvalues, + orderby, + start, + limit, + retcols, + order_direction="ASC", ): - """Executes a SELECT query on the named table with start and limit, + """ + Executes a SELECT query on the named table with start and limit, of row numbers, which may return zero or number of rows from start to limit, returning the result as a list of dicts. @@ -1439,65 +1461,33 @@ class SQLBaseStore(object): keyvalues (dict[str, T] | None): column names and values to select the rows with, or None to not apply a WHERE clause. - pagevalues ([]): - order (str): order the select by this column - start (int): start number to begin the query from - limit (int): number of rows to reterive + orderby (str): Column to order the results by. + start (int): Index to begin the query at. + limit (int): Number of results to return. retcols (iterable[str]): the names of the columns to return + order_direction (str): Whether the results should be ordered "ASC" or "DESC". Returns: defer.Deferred: resolves to list[dict[str, Any]] - """ + if order_direction not in ["ASC", "DESC"]: + raise ValueError("order_direction must be one of 'ASC' or 'DESC'.") + if keyvalues: - sql = "SELECT %s FROM %s WHERE %s ORDER BY %s" % ( - ", ".join(retcols), - table, - " AND ".join("%s = ?" % (k,) for k in keyvalues), - " ? ASC LIMIT ? OFFSET ?", - ) - txn.execute(sql, list(keyvalues.values()) + list(pagevalues)) + where_clause = "WHERE " + " AND ".join("%s = ?" % (k,) for k in keyvalues) else: - sql = "SELECT %s FROM %s ORDER BY %s" % ( - ", ".join(retcols), - table, - " ? ASC LIMIT ? OFFSET ?", - ) - txn.execute(sql, pagevalues) + where_clause = "" + + sql = "SELECT %s FROM %s %s ORDER BY %s %s LIMIT ? OFFSET ?" % ( + ", ".join(retcols), + table, + where_clause, + orderby, + order_direction, + ) + txn.execute(sql, list(keyvalues.values()) + [limit, start]) return cls.cursor_to_dict(txn) - @defer.inlineCallbacks - def get_user_list_paginate( - self, table, keyvalues, pagevalues, retcols, desc="get_user_list_paginate" - ): - """Get a list of users from start row to a limit number of rows. This will - return a json object with users and total number of users in users list. - - Args: - table (str): the table name - keyvalues (dict[str, Any] | None): - column names and values to select the rows with, or None to not - apply a WHERE clause. - pagevalues ([]): - order (str): order the select by this column - start (int): start number to begin the query from - limit (int): number of rows to reterive - retcols (iterable[str]): the names of the columns to return - Returns: - defer.Deferred: resolves to json object {list[dict[str, Any]], count} - """ - users = yield self.runInteraction( - desc, - self._simple_select_list_paginate_txn, - table, - keyvalues, - pagevalues, - retcols, - ) - count = yield self.runInteraction(desc, self.get_user_count_txn) - retval = {"users": users, "total": count} - defer.returnValue(retval) - def get_user_count_txn(self, txn): """Get a total number of registered users in the users list.