Rename direction to step, apply checks consistently

pull/684/head
Mark Haines 2016-04-01 13:50:54 +01:00
parent e36bfbab38
commit a2866e2e6a
2 changed files with 16 additions and 16 deletions

View File

@ -97,7 +97,7 @@ class DataStore(RoomMemberStore, RoomStore,
db_conn, "events", "stream_ordering" db_conn, "events", "stream_ordering"
) )
self._backfill_id_gen = StreamIdGenerator( self._backfill_id_gen = StreamIdGenerator(
db_conn, "events", "stream_ordering", direction=-1 db_conn, "events", "stream_ordering", step=-1
) )
self._receipts_id_gen = StreamIdGenerator( self._receipts_id_gen = StreamIdGenerator(
db_conn, "receipts_linearized", "stream_id" db_conn, "receipts_linearized", "stream_id"

View File

@ -29,16 +29,16 @@ class IdGenerator(object):
return self._next_id return self._next_id
def _load_current_id(db_conn, table, column, direction=1): def _load_current_id(db_conn, table, column, step=1):
cur = db_conn.cursor() cur = db_conn.cursor()
if direction == 1: if step == 1:
cur.execute("SELECT MAX(%s) FROM %s" % (column, table,)) cur.execute("SELECT MAX(%s) FROM %s" % (column, table,))
else: else:
cur.execute("SELECT MIN(%s) FROM %s" % (column, table,)) cur.execute("SELECT MIN(%s) FROM %s" % (column, table,))
val, = cur.fetchone() val, = cur.fetchone()
cur.close() cur.close()
current_id = int(val) if val else direction current_id = int(val) if val else step
return (max if direction == 1 else min)(current_id, direction) return (max if step > 0 else min)(current_id, step)
class StreamIdGenerator(object): class StreamIdGenerator(object):
@ -58,21 +58,21 @@ class StreamIdGenerator(object):
:param list extra_tables: List of pairs of database tables and columns to :param list extra_tables: List of pairs of database tables and columns to
use to source the initial value of the generator from. The value with use to source the initial value of the generator from. The value with
the largest magnitude is used. the largest magnitude is used.
:param int direction: which direction the stream ids grow in. +1 to grow :param int step: which direction the stream ids grow in. +1 to grow
upwards, -1 to grow downwards. upwards, -1 to grow downwards.
Usage: Usage:
with stream_id_gen.get_next() as stream_id: with stream_id_gen.get_next() as stream_id:
# ... persist event ... # ... persist event ...
""" """
def __init__(self, db_conn, table, column, extra_tables=[], direction=1): def __init__(self, db_conn, table, column, extra_tables=[], step=1):
self._lock = threading.Lock() self._lock = threading.Lock()
self._direction = direction self._step = step
self._current = _load_current_id(db_conn, table, column, direction) self._current = _load_current_id(db_conn, table, column, step)
for table, column in extra_tables: for table, column in extra_tables:
self._current = (max if direction > 0 else min)( self._current = (max if step > 0 else min)(
self._current, self._current,
_load_current_id(db_conn, table, column, direction) _load_current_id(db_conn, table, column, step)
) )
self._unfinished_ids = deque() self._unfinished_ids = deque()
@ -83,7 +83,7 @@ class StreamIdGenerator(object):
# ... persist event ... # ... persist event ...
""" """
with self._lock: with self._lock:
self._current += self._direction self._current += self._step
next_id = self._current next_id = self._current
self._unfinished_ids.append(next_id) self._unfinished_ids.append(next_id)
@ -106,9 +106,9 @@ class StreamIdGenerator(object):
""" """
with self._lock: with self._lock:
next_ids = range( next_ids = range(
self._current + self._direction, self._current + self._step,
self._current + self._direction * (n + 1), self._current + self._step * (n + 1),
self._direction self._step
) )
self._current += n self._current += n
@ -132,7 +132,7 @@ class StreamIdGenerator(object):
""" """
with self._lock: with self._lock:
if self._unfinished_ids: if self._unfinished_ids:
return self._unfinished_ids[0] - self._direction return self._unfinished_ids[0] - self._step
return self._current return self._current