split _get_events_from_db out of _enqueue_events

pull/5788/head
Richard van der Hoff 2019-07-24 16:37:50 +01:00
parent c9964ba600
commit e6a6c4fbab
1 changed files with 51 additions and 32 deletions

View File

@ -343,13 +343,12 @@ class EventsWorkerStore(SQLBaseStore):
log_ctx = LoggingContext.current_context() log_ctx = LoggingContext.current_context()
log_ctx.record_event_fetch(len(missing_events_ids)) log_ctx.record_event_fetch(len(missing_events_ids))
# Note that _enqueue_events is also responsible for turning db rows # Note that _get_events_from_db is also responsible for turning db rows
# into FrozenEvents (via _get_event_from_row), which involves seeing if # into FrozenEvents (via _get_event_from_row), which involves seeing if
# the events have been redacted, and if so pulling the redaction event out # the events have been redacted, and if so pulling the redaction event out
# of the database to check it. # of the database to check it.
# #
# _enqueue_events is a bit of a rubbish name but naming is hard. missing_events = yield self._get_events_from_db(
missing_events = yield self._enqueue_events(
missing_events_ids, allow_rejected=allow_rejected missing_events_ids, allow_rejected=allow_rejected
) )
@ -458,43 +457,25 @@ class EventsWorkerStore(SQLBaseStore):
self.hs.get_reactor().callFromThread(fire, event_list, e) self.hs.get_reactor().callFromThread(fire, event_list, e)
@defer.inlineCallbacks @defer.inlineCallbacks
def _enqueue_events(self, events, allow_rejected=False): def _get_events_from_db(self, event_ids, allow_rejected=False):
"""Fetches events from the database using the _event_fetch_list. This """Fetch a bunch of events from the database.
allows batch and bulk fetching of events - it allows us to fetch events
without having to create a new transaction for each request for events. Returned events will be added to the cache for future lookups.
Args: Args:
events (Iterable[str]): events to be fetched. event_ids (Iterable[str]): The event_ids of the events to fetch
allow_rejected (bool): Whether to include rejected events
Returns: Returns:
Deferred[Dict[str, _EventCacheEntry]]: map from event id to result. Deferred[Dict[str, _EventCacheEntry]]:
map from event id to result.
""" """
if not events: if not event_ids:
return {} return {}
events_d = defer.Deferred() row_map = yield self._enqueue_events(event_ids)
with self._event_fetch_lock:
self._event_fetch_list.append((events, events_d))
self._event_fetch_lock.notify() rows = (row_map.get(event_id) for event_id in event_ids)
if self._event_fetch_ongoing < EVENT_QUEUE_THREADS:
self._event_fetch_ongoing += 1
should_start = True
else:
should_start = False
if should_start:
run_as_background_process(
"fetch_events", self.runWithConnection, self._do_fetch
)
logger.debug("Loading %d events", len(events))
with PreserveLoggingContext():
row_map = yield events_d
logger.debug("Loaded %d events (%d rows)", len(events), len(row_map))
rows = (row_map.get(event_id) for event_id in events)
# filter out absent rows # filter out absent rows
rows = filter(operator.truth, rows) rows = filter(operator.truth, rows)
@ -521,6 +502,44 @@ class EventsWorkerStore(SQLBaseStore):
return {e.event.event_id: e for e in res if e} return {e.event.event_id: e for e in res if e}
@defer.inlineCallbacks
def _enqueue_events(self, events):
"""Fetches events from the database using the _event_fetch_list. This
allows batch and bulk fetching of events - it allows us to fetch events
without having to create a new transaction for each request for events.
Args:
events (Iterable[str]): events to be fetched.
Returns:
Deferred[Dict[str, Dict]]: map from event id to row data from the database.
May contain events that weren't requested.
"""
events_d = defer.Deferred()
with self._event_fetch_lock:
self._event_fetch_list.append((events, events_d))
self._event_fetch_lock.notify()
if self._event_fetch_ongoing < EVENT_QUEUE_THREADS:
self._event_fetch_ongoing += 1
should_start = True
else:
should_start = False
if should_start:
run_as_background_process(
"fetch_events", self.runWithConnection, self._do_fetch
)
logger.debug("Loading %d events: %s", len(events), events)
with PreserveLoggingContext():
row_map = yield events_d
logger.debug("Loaded %d events (%d rows)", len(events), len(row_map))
return row_map
def _fetch_event_rows(self, txn, event_ids): def _fetch_event_rows(self, txn, event_ids):
"""Fetch event rows from the database """Fetch event rows from the database