diff --git a/synapse/storage/events.py b/synapse/storage/events.py index 4aa4e7ab15..656e57b5c6 100644 --- a/synapse/storage/events.py +++ b/synapse/storage/events.py @@ -19,7 +19,6 @@ from twisted.internet import defer, reactor from synapse.events import FrozenEvent from synapse.events.utils import prune_event -from synapse.util import unwrap_deferred from synapse.util.logcontext import preserve_context_over_deferred from synapse.util.logutils import log_function @@ -401,11 +400,7 @@ class EventsStore(SQLBaseStore): @defer.inlineCallbacks def _get_events(self, event_ids, check_redacted=True, - get_prev_content=False, allow_rejected=False, txn=None): - """Gets a collection of events. If `txn` is not None the we use the - current transaction to fetch events and we return a deferred that is - guarenteed to have resolved. - """ + get_prev_content=False, allow_rejected=False): if not event_ids: defer.returnValue([]) @@ -424,21 +419,12 @@ class EventsStore(SQLBaseStore): if e_id in event_map and event_map[e_id] ]) - if not txn: - missing_events = yield self._enqueue_events( - missing_events_ids, - check_redacted=check_redacted, - get_prev_content=get_prev_content, - allow_rejected=allow_rejected, - ) - else: - missing_events = self._fetch_events_txn( - txn, - missing_events_ids, - check_redacted=check_redacted, - get_prev_content=get_prev_content, - allow_rejected=allow_rejected, - ) + missing_events = yield self._enqueue_events( + missing_events_ids, + check_redacted=check_redacted, + get_prev_content=get_prev_content, + allow_rejected=allow_rejected, + ) event_map.update(missing_events) @@ -449,13 +435,38 @@ class EventsStore(SQLBaseStore): def _get_events_txn(self, txn, event_ids, check_redacted=True, get_prev_content=False, allow_rejected=False): - return unwrap_deferred(self._get_events( + if not event_ids: + return [] + + event_map = self._get_events_from_cache( event_ids, check_redacted=check_redacted, get_prev_content=get_prev_content, allow_rejected=allow_rejected, - txn=txn, - )) + ) + + missing_events_ids = [e for e in event_ids if e not in event_map] + + if not missing_events_ids: + return [ + event_map[e_id] for e_id in event_ids + if e_id in event_map and event_map[e_id] + ] + + missing_events = self._fetch_events_txn( + txn, + missing_events_ids, + check_redacted=check_redacted, + get_prev_content=get_prev_content, + allow_rejected=allow_rejected, + ) + + event_map.update(missing_events) + + return [ + event_map[e_id] for e_id in event_ids + if e_id in event_map and event_map[e_id] + ] def _invalidate_get_event_cache(self, event_id): for check_redacted in (False, True):