Remove race condition

pull/155/head
Erik Johnston 2015-05-14 16:54:35 +01:00
parent ef3d8754f5
commit 1d566edb81
4 changed files with 161 additions and 100 deletions

View File

@ -26,6 +26,8 @@ from util.id_generators import IdGenerator, StreamIdGenerator
from twisted.internet import defer
from collections import namedtuple, OrderedDict
import contextlib
import functools
import sys
import time
@ -299,7 +301,7 @@ class SQLBaseStore(object):
self._get_event_cache = Cache("*getEvent*", keylen=3, lru=True,
max_entries=hs.config.event_cache_size)
self._event_fetch_lock = threading.Lock()
self._event_fetch_lock = threading.Condition()
self._event_fetch_list = []
self._event_fetch_ongoing = 0
@ -342,6 +344,84 @@ class SQLBaseStore(object):
self._clock.looping_call(loop, 10000)
@contextlib.contextmanager
def _new_transaction(self, conn, desc, after_callbacks):
start = time.time() * 1000
txn_id = self._TXN_ID
# We don't really need these to be unique, so lets stop it from
# growing really large.
self._TXN_ID = (self._TXN_ID + 1) % (sys.maxint - 1)
name = "%s-%x" % (desc, txn_id, )
transaction_logger.debug("[TXN START] {%s}", name)
try:
i = 0
N = 5
while True:
try:
txn = conn.cursor()
txn = LoggingTransaction(
txn, name, self.database_engine, after_callbacks
)
except self.database_engine.module.OperationalError as e:
# This can happen if the database disappears mid
# transaction.
logger.warn(
"[TXN OPERROR] {%s} %s %d/%d",
name, e, i, N
)
if i < N:
i += 1
try:
conn.rollback()
except self.database_engine.module.Error as e1:
logger.warn(
"[TXN EROLL] {%s} %s",
name, e1,
)
continue
raise
except self.database_engine.module.DatabaseError as e:
if self.database_engine.is_deadlock(e):
logger.warn("[TXN DEADLOCK] {%s} %d/%d", name, i, N)
if i < N:
i += 1
try:
conn.rollback()
except self.database_engine.module.Error as e1:
logger.warn(
"[TXN EROLL] {%s} %s",
name, e1,
)
continue
raise
try:
yield txn
conn.commit()
return
except:
try:
conn.rollback()
except:
pass
raise
except Exception as e:
logger.debug("[TXN FAIL] {%s} %s", name, e)
raise
finally:
end = time.time() * 1000
duration = end - start
transaction_logger.debug("[TXN END] {%s} %f", name, duration)
self._current_txn_total_time += duration
self._txn_perf_counters.update(desc, start, end)
sql_txn_timer.inc_by(duration, desc)
@defer.inlineCallbacks
def runInteraction(self, desc, func, *args, **kwargs):
"""Wraps the .runInteraction() method on the underlying db_pool."""
@ -353,75 +433,15 @@ class SQLBaseStore(object):
def inner_func(conn, *args, **kwargs):
with LoggingContext("runInteraction") as context:
sql_scheduling_timer.inc_by(time.time() * 1000 - start_time)
if self.database_engine.is_connection_closed(conn):
logger.debug("Reconnecting closed database connection")
conn.reconnect()
current_context.copy_to(context)
start = time.time() * 1000
txn_id = self._TXN_ID
# We don't really need these to be unique, so lets stop it from
# growing really large.
self._TXN_ID = (self._TXN_ID + 1) % (sys.maxint - 1)
name = "%s-%x" % (desc, txn_id, )
sql_scheduling_timer.inc_by(time.time() * 1000 - start_time)
transaction_logger.debug("[TXN START] {%s}", name)
try:
i = 0
N = 5
while True:
try:
txn = conn.cursor()
txn = LoggingTransaction(
txn, name, self.database_engine, after_callbacks
)
return func(txn, *args, **kwargs)
except self.database_engine.module.OperationalError as e:
# This can happen if the database disappears mid
# transaction.
logger.warn(
"[TXN OPERROR] {%s} %s %d/%d",
name, e, i, N
)
if i < N:
i += 1
try:
conn.rollback()
except self.database_engine.module.Error as e1:
logger.warn(
"[TXN EROLL] {%s} %s",
name, e1,
)
continue
except self.database_engine.module.DatabaseError as e:
if self.database_engine.is_deadlock(e):
logger.warn("[TXN DEADLOCK] {%s} %d/%d", name, i, N)
if i < N:
i += 1
try:
conn.rollback()
except self.database_engine.module.Error as e1:
logger.warn(
"[TXN EROLL] {%s} %s",
name, e1,
)
continue
raise
except Exception as e:
logger.debug("[TXN FAIL] {%s} %s", name, e)
raise
finally:
end = time.time() * 1000
duration = end - start
transaction_logger.debug("[TXN END] {%s} %f", name, duration)
self._current_txn_total_time += duration
self._txn_perf_counters.update(desc, start, end)
sql_txn_timer.inc_by(duration, desc)
with self._new_transaction(conn, desc, after_callbacks) as txn:
return func(txn, *args, **kwargs)
result = yield preserve_context_over_fn(
self._db_pool.runWithConnection,
@ -432,6 +452,32 @@ class SQLBaseStore(object):
after_callback(*after_args)
defer.returnValue(result)
@defer.inlineCallbacks
def runWithConnection(self, func, *args, **kwargs):
"""Wraps the .runInteraction() method on the underlying db_pool."""
current_context = LoggingContext.current_context()
start_time = time.time() * 1000
def inner_func(conn, *args, **kwargs):
with LoggingContext("runWithConnection") as context:
sql_scheduling_timer.inc_by(time.time() * 1000 - start_time)
if self.database_engine.is_connection_closed(conn):
logger.debug("Reconnecting closed database connection")
conn.reconnect()
current_context.copy_to(context)
return func(conn, *args, **kwargs)
result = yield preserve_context_over_fn(
self._db_pool.runWithConnection,
inner_func, *args, **kwargs
)
defer.returnValue(result)
def cursor_to_dict(self, cursor):
"""Converts a SQL cursor into an list of dicts.

View File

@ -19,6 +19,8 @@ from ._base import IncorrectDatabaseSetup
class PostgresEngine(object):
single_threaded = False
def __init__(self, database_module):
self.module = database_module
self.module.extensions.register_type(self.module.extensions.UNICODE)

View File

@ -17,6 +17,8 @@ from synapse.storage import prepare_database, prepare_sqlite3_database
class Sqlite3Engine(object):
single_threaded = True
def __init__(self, database_module):
self.module = database_module

View File

@ -504,23 +504,26 @@ class EventsStore(SQLBaseStore):
if not events:
defer.returnValue({})
def do_fetch(txn):
def do_fetch(conn):
event_list = []
while True:
try:
with self._event_fetch_lock:
event_list = self._event_fetch_list
self._event_fetch_list = []
if not event_list:
i = 0
while not self._event_fetch_list:
self._event_fetch_ongoing -= 1
return
event_list = self._event_fetch_list
self._event_fetch_list = []
event_id_lists = zip(*event_list)[0]
event_ids = [
item for sublist in event_id_lists for item in sublist
]
rows = self._fetch_event_rows(txn, event_ids)
with self._new_transaction(conn, "do_fetch", []) as txn:
rows = self._fetch_event_rows(txn, event_ids)
row_dict = {
r["event_id"]: r
@ -528,22 +531,44 @@ class EventsStore(SQLBaseStore):
}
for ids, d in event_list:
reactor.callFromThread(
d.callback,
[
row_dict[i] for i in ids
if i in row_dict
]
)
def fire():
if not d.called:
d.callback(
[
row_dict[i]
for i in ids
if i in row_dict
]
)
reactor.callFromThread(fire)
except Exception as e:
logger.exception("do_fetch")
for _, d in event_list:
try:
if not d.called:
reactor.callFromThread(d.errback, e)
except:
pass
def cb(rows):
return defer.gatherResults([
with self._event_fetch_lock:
self._event_fetch_ongoing -= 1
return
events_d = defer.Deferred()
with self._event_fetch_lock:
self._event_fetch_list.append(
(events, events_d)
)
self._event_fetch_lock.notify_all()
# if self._event_fetch_ongoing < 5:
self._event_fetch_ongoing += 1
self.runWithConnection(
do_fetch
)
rows = yield events_d
res = yield defer.gatherResults(
[
self._get_event_from_row(
None,
row["internal_metadata"], row["json"], row["redacts"],
@ -552,23 +577,9 @@ class EventsStore(SQLBaseStore):
rejected_reason=row["rejects"],
)
for row in rows
])
d = defer.Deferred()
d.addCallback(cb)
with self._event_fetch_lock:
self._event_fetch_list.append(
(events, d)
)
if self._event_fetch_ongoing < 3:
self._event_fetch_ongoing += 1
self.runInteraction(
"do_fetch",
do_fetch
)
res = yield d
],
consumeErrors=True
)
defer.returnValue({
e.event_id: e