diff --git a/changelog.d/5191.misc b/changelog.d/5191.misc new file mode 100644 index 0000000000..e0615fec9c --- /dev/null +++ b/changelog.d/5191.misc @@ -0,0 +1 @@ +Make generating SQL bounds for pagination generic. diff --git a/synapse/storage/stream.py b/synapse/storage/stream.py index d105b6b17d..529ad4ea79 100644 --- a/synapse/storage/stream.py +++ b/synapse/storage/stream.py @@ -64,59 +64,135 @@ _EventDictReturn = namedtuple( ) -def lower_bound(token, engine, inclusive=False): - inclusive = "=" if inclusive else "" - if token.topological is None: - return "(%d <%s %s)" % (token.stream, inclusive, "stream_ordering") - else: - if isinstance(engine, PostgresEngine): - # Postgres doesn't optimise ``(x < a) OR (x=a AND y= (topological_ordering, stream_ordering) + AND (5, 3) < (topological_ordering, stream_ordering) + + would be generated for dir=b, from_token=(6, 7) and to_token=(5, 3). + + Note that tokens are considered to be after the row they are in, e.g. if + a row A has a token T, then we consider A to be before T. This convention + is important when figuring out inequalities for the generated SQL, and + produces the following result: + - If paginating forwards then we exclude any rows matching the from + token, but include those that match the to token. + - If paginating backwards then we include any rows matching the from + token, but include those that match the to token. + + Args: + direction (str): Whether we're paginating backwards("b") or + forwards ("f"). + column_names (tuple[str, str]): The column names to bound. Must *not* + be user defined as these get inserted directly into the SQL + statement without escapes. + from_token (tuple[int, int]|None): The start point for the pagination. + This is an exclusive minimum bound if direction is "f", and an + inclusive maximum bound if direction is "b". + to_token (tuple[int, int]|None): The endpoint point for the pagination. + This is an inclusive maximum bound if direction is "f", and an + exclusive minimum bound if direction is "b". + engine: The database engine to generate the clauses for + + Returns: + str: The sql expression + """ + assert direction in ("b", "f") + + where_clause = [] + if from_token: + where_clause.append( + _make_generic_sql_bound( + bound=">=" if direction == "b" else "<", + column_names=column_names, + values=from_token, + engine=engine, ) - return "(%d < %s OR (%d = %s AND %d <%s %s))" % ( - token.topological, - "topological_ordering", - token.topological, - "topological_ordering", - token.stream, - inclusive, - "stream_ordering", ) - -def upper_bound(token, engine, inclusive=True): - inclusive = "=" if inclusive else "" - if token.topological is None: - return "(%d >%s %s)" % (token.stream, inclusive, "stream_ordering") - else: - if isinstance(engine, PostgresEngine): - # Postgres doesn't optimise ``(x > a) OR (x=a AND y>b)`` as well - # as it optimises ``(x,y) > (a,b)`` on multicolumn indexes. So we - # use the later form when running against postgres. - return "((%d,%d) >%s (%s,%s))" % ( - token.topological, - token.stream, - inclusive, - "topological_ordering", - "stream_ordering", + if to_token: + where_clause.append( + _make_generic_sql_bound( + bound="<" if direction == "b" else ">=", + column_names=column_names, + values=to_token, + engine=engine, ) - return "(%d > %s OR (%d = %s AND %d >%s %s))" % ( - token.topological, - "topological_ordering", - token.topological, - "topological_ordering", - token.stream, - inclusive, - "stream_ordering", ) + return " AND ".join(where_clause) + + +def _make_generic_sql_bound(bound, column_names, values, engine): + """Create an SQL expression that bounds the given column names by the + values, e.g. create the equivalent of `(1, 2) < (col1, col2)`. + + Only works with two columns. + + Older versions of SQLite don't support that syntax so we have to expand it + out manually. + + Args: + bound (str): The comparison operator to use. One of ">", "<", ">=", + "<=", where the values are on the left and columns on the right. + names (tuple[str, str]): The column names. Must *not* be user defined + as these get inserted directly into the SQL statement without + escapes. + values (tuple[int|None, int]): The values to bound the columns by. If + the first value is None then only creates a bound on the second + column. + engine: The database engine to generate the SQL for + + Returns: + str + """ + + assert(bound in (">", "<", ">=", "<=")) + + name1, name2 = column_names + val1, val2 = values + + if val1 is None: + val2 = int(val2) + return "(%d %s %s)" % (val2, bound, name2) + + val1 = int(val1) + val2 = int(val2) + + if isinstance(engine, PostgresEngine): + # Postgres doesn't optimise ``(x < a) OR (x=a AND y