Split out get_events and co into a worker store

pull/2903/head
Erik Johnston 2018-02-21 11:41:48 +00:00
parent a2b25de68d
commit 27b094f382
2 changed files with 352 additions and 356 deletions

View File

@ -18,6 +18,7 @@ from synapse.api.constants import EventTypes
from synapse.storage import DataStore from synapse.storage import DataStore
from synapse.storage.event_federation import EventFederationStore from synapse.storage.event_federation import EventFederationStore
from synapse.storage.event_push_actions import EventPushActionsStore from synapse.storage.event_push_actions import EventPushActionsStore
from synapse.storage.events import EventsWorkerStore
from synapse.storage.roommember import RoomMemberStore from synapse.storage.roommember import RoomMemberStore
from synapse.storage.state import StateGroupWorkerStore from synapse.storage.state import StateGroupWorkerStore
from synapse.storage.stream import StreamStore from synapse.storage.stream import StreamStore
@ -38,7 +39,7 @@ logger = logging.getLogger(__name__)
# the method descriptor on the DataStore and chuck them into our class. # the method descriptor on the DataStore and chuck them into our class.
class SlavedEventStore(StateGroupWorkerStore, BaseSlavedStore): class SlavedEventStore(EventsWorkerStore, StateGroupWorkerStore, BaseSlavedStore):
def __init__(self, db_conn, hs): def __init__(self, db_conn, hs):
super(SlavedEventStore, self).__init__(db_conn, hs) super(SlavedEventStore, self).__init__(db_conn, hs)
@ -104,8 +105,6 @@ class SlavedEventStore(StateGroupWorkerStore, BaseSlavedStore):
get_push_action_users_in_range = ( get_push_action_users_in_range = (
DataStore.get_push_action_users_in_range.__func__ DataStore.get_push_action_users_in_range.__func__
) )
get_event = DataStore.get_event.__func__
get_events = DataStore.get_events.__func__
get_rooms_for_user_where_membership_is = ( get_rooms_for_user_where_membership_is = (
DataStore.get_rooms_for_user_where_membership_is.__func__ DataStore.get_rooms_for_user_where_membership_is.__func__
) )
@ -135,14 +134,6 @@ class SlavedEventStore(StateGroupWorkerStore, BaseSlavedStore):
_set_before_and_after = staticmethod(DataStore._set_before_and_after) _set_before_and_after = staticmethod(DataStore._set_before_and_after)
_get_events = DataStore._get_events.__func__
_get_events_from_cache = DataStore._get_events_from_cache.__func__
_invalidate_get_event_cache = DataStore._invalidate_get_event_cache.__func__
_enqueue_events = DataStore._enqueue_events.__func__
_do_fetch = DataStore._do_fetch.__func__
_fetch_event_rows = DataStore._fetch_event_rows.__func__
_get_event_from_row = DataStore._get_event_from_row.__func__
_get_rooms_for_user_where_membership_is_txn = ( _get_rooms_for_user_where_membership_is_txn = (
DataStore._get_rooms_for_user_where_membership_is_txn.__func__ DataStore._get_rooms_for_user_where_membership_is_txn.__func__
) )

View File

@ -199,7 +199,356 @@ def _retry_on_integrity_error(func):
return f return f
class EventsStore(SQLBaseStore): class EventsWorkerStore(SQLBaseStore):
def __init__(self, db_conn, hs):
super(EventsWorkerStore, self).__init__(db_conn, hs)
self._event_persist_queue = _EventPeristenceQueue()
@defer.inlineCallbacks
def get_event(self, event_id, check_redacted=True,
get_prev_content=False, allow_rejected=False,
allow_none=False):
"""Get an event from the database by event_id.
Args:
event_id (str): The event_id of the event to fetch
check_redacted (bool): If True, check if event has been redacted
and redact it.
get_prev_content (bool): If True and event is a state event,
include the previous states content in the unsigned field.
allow_rejected (bool): If True return rejected events.
allow_none (bool): If True, return None if no event found, if
False throw an exception.
Returns:
Deferred : A FrozenEvent.
"""
events = yield self._get_events(
[event_id],
check_redacted=check_redacted,
get_prev_content=get_prev_content,
allow_rejected=allow_rejected,
)
if not events and not allow_none:
raise SynapseError(404, "Could not find event %s" % (event_id,))
defer.returnValue(events[0] if events else None)
@defer.inlineCallbacks
def get_events(self, event_ids, check_redacted=True,
get_prev_content=False, allow_rejected=False):
"""Get events from the database
Args:
event_ids (list): The event_ids of the events to fetch
check_redacted (bool): If True, check if event has been redacted
and redact it.
get_prev_content (bool): If True and event is a state event,
include the previous states content in the unsigned field.
allow_rejected (bool): If True return rejected events.
Returns:
Deferred : Dict from event_id to event.
"""
events = yield self._get_events(
event_ids,
check_redacted=check_redacted,
get_prev_content=get_prev_content,
allow_rejected=allow_rejected,
)
defer.returnValue({e.event_id: e for e in events})
@defer.inlineCallbacks
def _get_events(self, event_ids, check_redacted=True,
get_prev_content=False, allow_rejected=False):
if not event_ids:
defer.returnValue([])
event_id_list = event_ids
event_ids = set(event_ids)
event_entry_map = self._get_events_from_cache(
event_ids,
allow_rejected=allow_rejected,
)
missing_events_ids = [e for e in event_ids if e not in event_entry_map]
if missing_events_ids:
missing_events = yield self._enqueue_events(
missing_events_ids,
check_redacted=check_redacted,
allow_rejected=allow_rejected,
)
event_entry_map.update(missing_events)
events = []
for event_id in event_id_list:
entry = event_entry_map.get(event_id, None)
if not entry:
continue
if allow_rejected or not entry.event.rejected_reason:
if check_redacted and entry.redacted_event:
event = entry.redacted_event
else:
event = entry.event
events.append(event)
if get_prev_content:
if "replaces_state" in event.unsigned:
prev = yield self.get_event(
event.unsigned["replaces_state"],
get_prev_content=False,
allow_none=True,
)
if prev:
event.unsigned = dict(event.unsigned)
event.unsigned["prev_content"] = prev.content
event.unsigned["prev_sender"] = prev.sender
defer.returnValue(events)
def _invalidate_get_event_cache(self, event_id):
self._get_event_cache.invalidate((event_id,))
def _get_events_from_cache(self, events, allow_rejected, update_metrics=True):
"""Fetch events from the caches
Args:
events (list(str)): list of event_ids to fetch
allow_rejected (bool): Whether to teturn events that were rejected
update_metrics (bool): Whether to update the cache hit ratio metrics
Returns:
dict of event_id -> _EventCacheEntry for each event_id in cache. If
allow_rejected is `False` then there will still be an entry but it
will be `None`
"""
event_map = {}
for event_id in events:
ret = self._get_event_cache.get(
(event_id,), None,
update_metrics=update_metrics,
)
if not ret:
continue
if allow_rejected or not ret.event.rejected_reason:
event_map[event_id] = ret
else:
event_map[event_id] = None
return event_map
def _do_fetch(self, conn):
"""Takes a database connection and waits for requests for events from
the _event_fetch_list queue.
"""
event_list = []
i = 0
while True:
try:
with self._event_fetch_lock:
event_list = self._event_fetch_list
self._event_fetch_list = []
if not event_list:
single_threaded = self.database_engine.single_threaded
if single_threaded or i > EVENT_QUEUE_ITERATIONS:
self._event_fetch_ongoing -= 1
return
else:
self._event_fetch_lock.wait(EVENT_QUEUE_TIMEOUT_S)
i += 1
continue
i = 0
event_id_lists = zip(*event_list)[0]
event_ids = [
item for sublist in event_id_lists for item in sublist
]
rows = self._new_transaction(
conn, "do_fetch", [], [], None, self._fetch_event_rows, event_ids
)
row_dict = {
r["event_id"]: r
for r in rows
}
# We only want to resolve deferreds from the main thread
def fire(lst, res):
for ids, d in lst:
if not d.called:
try:
with PreserveLoggingContext():
d.callback([
res[i]
for i in ids
if i in res
])
except Exception:
logger.exception("Failed to callback")
with PreserveLoggingContext():
reactor.callFromThread(fire, event_list, row_dict)
except Exception as e:
logger.exception("do_fetch")
# We only want to resolve deferreds from the main thread
def fire(evs):
for _, d in evs:
if not d.called:
with PreserveLoggingContext():
d.errback(e)
if event_list:
with PreserveLoggingContext():
reactor.callFromThread(fire, event_list)
@defer.inlineCallbacks
def _enqueue_events(self, events, check_redacted=True, allow_rejected=False):
"""Fetches events from the database using the _event_fetch_list. This
allows batch and bulk fetching of events - it allows us to fetch events
without having to create a new transaction for each request for events.
"""
if not events:
defer.returnValue({})
events_d = defer.Deferred()
with self._event_fetch_lock:
self._event_fetch_list.append(
(events, events_d)
)
self._event_fetch_lock.notify()
if self._event_fetch_ongoing < EVENT_QUEUE_THREADS:
self._event_fetch_ongoing += 1
should_start = True
else:
should_start = False
if should_start:
with PreserveLoggingContext():
self.runWithConnection(
self._do_fetch
)
logger.debug("Loading %d events", len(events))
with PreserveLoggingContext():
rows = yield events_d
logger.debug("Loaded %d events (%d rows)", len(events), len(rows))
if not allow_rejected:
rows[:] = [r for r in rows if not r["rejects"]]
res = yield make_deferred_yieldable(defer.gatherResults(
[
preserve_fn(self._get_event_from_row)(
row["internal_metadata"], row["json"], row["redacts"],
rejected_reason=row["rejects"],
)
for row in rows
],
consumeErrors=True
))
defer.returnValue({
e.event.event_id: e
for e in res if e
})
def _fetch_event_rows(self, txn, events):
rows = []
N = 200
for i in range(1 + len(events) / N):
evs = events[i * N:(i + 1) * N]
if not evs:
break
sql = (
"SELECT "
" e.event_id as event_id, "
" e.internal_metadata,"
" e.json,"
" r.redacts as redacts,"
" rej.event_id as rejects "
" FROM event_json as e"
" LEFT JOIN rejections as rej USING (event_id)"
" LEFT JOIN redactions as r ON e.event_id = r.redacts"
" WHERE e.event_id IN (%s)"
) % (",".join(["?"] * len(evs)),)
txn.execute(sql, evs)
rows.extend(self.cursor_to_dict(txn))
return rows
@defer.inlineCallbacks
def _get_event_from_row(self, internal_metadata, js, redacted,
rejected_reason=None):
with Measure(self._clock, "_get_event_from_row"):
d = json.loads(js)
internal_metadata = json.loads(internal_metadata)
if rejected_reason:
rejected_reason = yield self._simple_select_one_onecol(
table="rejections",
keyvalues={"event_id": rejected_reason},
retcol="reason",
desc="_get_event_from_row_rejected_reason",
)
original_ev = FrozenEvent(
d,
internal_metadata_dict=internal_metadata,
rejected_reason=rejected_reason,
)
redacted_event = None
if redacted:
redacted_event = prune_event(original_ev)
redaction_id = yield self._simple_select_one_onecol(
table="redactions",
keyvalues={"redacts": redacted_event.event_id},
retcol="event_id",
desc="_get_event_from_row_redactions",
)
redacted_event.unsigned["redacted_by"] = redaction_id
# Get the redaction event.
because = yield self.get_event(
redaction_id,
check_redacted=False,
allow_none=True,
)
if because:
# It's fine to do add the event directly, since get_pdu_json
# will serialise this field correctly
redacted_event.unsigned["redacted_because"] = because
cache_entry = _EventCacheEntry(
event=original_ev,
redacted_event=redacted_event,
)
self._get_event_cache.prefill((original_ev.event_id,), cache_entry)
defer.returnValue(cache_entry)
class EventsStore(EventsWorkerStore):
EVENT_ORIGIN_SERVER_TS_NAME = "event_origin_server_ts" EVENT_ORIGIN_SERVER_TS_NAME = "event_origin_server_ts"
EVENT_FIELDS_SENDER_URL_UPDATE_NAME = "event_fields_sender_url" EVENT_FIELDS_SENDER_URL_UPDATE_NAME = "event_fields_sender_url"
@ -234,8 +583,6 @@ class EventsStore(SQLBaseStore):
psql_only=True, psql_only=True,
) )
self._event_persist_queue = _EventPeristenceQueue()
self._state_resolution_handler = hs.get_state_resolution_handler() self._state_resolution_handler = hs.get_state_resolution_handler()
def persist_events(self, events_and_contexts, backfilled=False): def persist_events(self, events_and_contexts, backfilled=False):
@ -609,62 +956,6 @@ class EventsStore(SQLBaseStore):
defer.returnValue((to_delete, to_insert)) defer.returnValue((to_delete, to_insert))
@defer.inlineCallbacks
def get_event(self, event_id, check_redacted=True,
get_prev_content=False, allow_rejected=False,
allow_none=False):
"""Get an event from the database by event_id.
Args:
event_id (str): The event_id of the event to fetch
check_redacted (bool): If True, check if event has been redacted
and redact it.
get_prev_content (bool): If True and event is a state event,
include the previous states content in the unsigned field.
allow_rejected (bool): If True return rejected events.
allow_none (bool): If True, return None if no event found, if
False throw an exception.
Returns:
Deferred : A FrozenEvent.
"""
events = yield self._get_events(
[event_id],
check_redacted=check_redacted,
get_prev_content=get_prev_content,
allow_rejected=allow_rejected,
)
if not events and not allow_none:
raise SynapseError(404, "Could not find event %s" % (event_id,))
defer.returnValue(events[0] if events else None)
@defer.inlineCallbacks
def get_events(self, event_ids, check_redacted=True,
get_prev_content=False, allow_rejected=False):
"""Get events from the database
Args:
event_ids (list): The event_ids of the events to fetch
check_redacted (bool): If True, check if event has been redacted
and redact it.
get_prev_content (bool): If True and event is a state event,
include the previous states content in the unsigned field.
allow_rejected (bool): If True return rejected events.
Returns:
Deferred : Dict from event_id to event.
"""
events = yield self._get_events(
event_ids,
check_redacted=check_redacted,
get_prev_content=get_prev_content,
allow_rejected=allow_rejected,
)
defer.returnValue({e.event_id: e for e in events})
@log_function @log_function
def _persist_events_txn(self, txn, events_and_contexts, backfilled, def _persist_events_txn(self, txn, events_and_contexts, backfilled,
delete_existing=False, state_delta_for_room={}, delete_existing=False, state_delta_for_room={},
@ -1375,292 +1666,6 @@ class EventsStore(SQLBaseStore):
"have_events", f, "have_events", f,
) )
@defer.inlineCallbacks
def _get_events(self, event_ids, check_redacted=True,
get_prev_content=False, allow_rejected=False):
if not event_ids:
defer.returnValue([])
event_id_list = event_ids
event_ids = set(event_ids)
event_entry_map = self._get_events_from_cache(
event_ids,
allow_rejected=allow_rejected,
)
missing_events_ids = [e for e in event_ids if e not in event_entry_map]
if missing_events_ids:
missing_events = yield self._enqueue_events(
missing_events_ids,
check_redacted=check_redacted,
allow_rejected=allow_rejected,
)
event_entry_map.update(missing_events)
events = []
for event_id in event_id_list:
entry = event_entry_map.get(event_id, None)
if not entry:
continue
if allow_rejected or not entry.event.rejected_reason:
if check_redacted and entry.redacted_event:
event = entry.redacted_event
else:
event = entry.event
events.append(event)
if get_prev_content:
if "replaces_state" in event.unsigned:
prev = yield self.get_event(
event.unsigned["replaces_state"],
get_prev_content=False,
allow_none=True,
)
if prev:
event.unsigned = dict(event.unsigned)
event.unsigned["prev_content"] = prev.content
event.unsigned["prev_sender"] = prev.sender
defer.returnValue(events)
def _invalidate_get_event_cache(self, event_id):
self._get_event_cache.invalidate((event_id,))
def _get_events_from_cache(self, events, allow_rejected, update_metrics=True):
"""Fetch events from the caches
Args:
events (list(str)): list of event_ids to fetch
allow_rejected (bool): Whether to teturn events that were rejected
update_metrics (bool): Whether to update the cache hit ratio metrics
Returns:
dict of event_id -> _EventCacheEntry for each event_id in cache. If
allow_rejected is `False` then there will still be an entry but it
will be `None`
"""
event_map = {}
for event_id in events:
ret = self._get_event_cache.get(
(event_id,), None,
update_metrics=update_metrics,
)
if not ret:
continue
if allow_rejected or not ret.event.rejected_reason:
event_map[event_id] = ret
else:
event_map[event_id] = None
return event_map
def _do_fetch(self, conn):
"""Takes a database connection and waits for requests for events from
the _event_fetch_list queue.
"""
event_list = []
i = 0
while True:
try:
with self._event_fetch_lock:
event_list = self._event_fetch_list
self._event_fetch_list = []
if not event_list:
single_threaded = self.database_engine.single_threaded
if single_threaded or i > EVENT_QUEUE_ITERATIONS:
self._event_fetch_ongoing -= 1
return
else:
self._event_fetch_lock.wait(EVENT_QUEUE_TIMEOUT_S)
i += 1
continue
i = 0
event_id_lists = zip(*event_list)[0]
event_ids = [
item for sublist in event_id_lists for item in sublist
]
rows = self._new_transaction(
conn, "do_fetch", [], [], None, self._fetch_event_rows, event_ids
)
row_dict = {
r["event_id"]: r
for r in rows
}
# We only want to resolve deferreds from the main thread
def fire(lst, res):
for ids, d in lst:
if not d.called:
try:
with PreserveLoggingContext():
d.callback([
res[i]
for i in ids
if i in res
])
except Exception:
logger.exception("Failed to callback")
with PreserveLoggingContext():
reactor.callFromThread(fire, event_list, row_dict)
except Exception as e:
logger.exception("do_fetch")
# We only want to resolve deferreds from the main thread
def fire(evs):
for _, d in evs:
if not d.called:
with PreserveLoggingContext():
d.errback(e)
if event_list:
with PreserveLoggingContext():
reactor.callFromThread(fire, event_list)
@defer.inlineCallbacks
def _enqueue_events(self, events, check_redacted=True, allow_rejected=False):
"""Fetches events from the database using the _event_fetch_list. This
allows batch and bulk fetching of events - it allows us to fetch events
without having to create a new transaction for each request for events.
"""
if not events:
defer.returnValue({})
events_d = defer.Deferred()
with self._event_fetch_lock:
self._event_fetch_list.append(
(events, events_d)
)
self._event_fetch_lock.notify()
if self._event_fetch_ongoing < EVENT_QUEUE_THREADS:
self._event_fetch_ongoing += 1
should_start = True
else:
should_start = False
if should_start:
with PreserveLoggingContext():
self.runWithConnection(
self._do_fetch
)
logger.debug("Loading %d events", len(events))
with PreserveLoggingContext():
rows = yield events_d
logger.debug("Loaded %d events (%d rows)", len(events), len(rows))
if not allow_rejected:
rows[:] = [r for r in rows if not r["rejects"]]
res = yield make_deferred_yieldable(defer.gatherResults(
[
preserve_fn(self._get_event_from_row)(
row["internal_metadata"], row["json"], row["redacts"],
rejected_reason=row["rejects"],
)
for row in rows
],
consumeErrors=True
))
defer.returnValue({
e.event.event_id: e
for e in res if e
})
def _fetch_event_rows(self, txn, events):
rows = []
N = 200
for i in range(1 + len(events) / N):
evs = events[i * N:(i + 1) * N]
if not evs:
break
sql = (
"SELECT "
" e.event_id as event_id, "
" e.internal_metadata,"
" e.json,"
" r.redacts as redacts,"
" rej.event_id as rejects "
" FROM event_json as e"
" LEFT JOIN rejections as rej USING (event_id)"
" LEFT JOIN redactions as r ON e.event_id = r.redacts"
" WHERE e.event_id IN (%s)"
) % (",".join(["?"] * len(evs)),)
txn.execute(sql, evs)
rows.extend(self.cursor_to_dict(txn))
return rows
@defer.inlineCallbacks
def _get_event_from_row(self, internal_metadata, js, redacted,
rejected_reason=None):
with Measure(self._clock, "_get_event_from_row"):
d = json.loads(js)
internal_metadata = json.loads(internal_metadata)
if rejected_reason:
rejected_reason = yield self._simple_select_one_onecol(
table="rejections",
keyvalues={"event_id": rejected_reason},
retcol="reason",
desc="_get_event_from_row_rejected_reason",
)
original_ev = FrozenEvent(
d,
internal_metadata_dict=internal_metadata,
rejected_reason=rejected_reason,
)
redacted_event = None
if redacted:
redacted_event = prune_event(original_ev)
redaction_id = yield self._simple_select_one_onecol(
table="redactions",
keyvalues={"redacts": redacted_event.event_id},
retcol="event_id",
desc="_get_event_from_row_redactions",
)
redacted_event.unsigned["redacted_by"] = redaction_id
# Get the redaction event.
because = yield self.get_event(
redaction_id,
check_redacted=False,
allow_none=True,
)
if because:
# It's fine to do add the event directly, since get_pdu_json
# will serialise this field correctly
redacted_event.unsigned["redacted_because"] = because
cache_entry = _EventCacheEntry(
event=original_ev,
redacted_event=redacted_event,
)
self._get_event_cache.prefill((original_ev.event_id,), cache_entry)
defer.returnValue(cache_entry)
@defer.inlineCallbacks @defer.inlineCallbacks
def count_daily_messages(self): def count_daily_messages(self):
""" """