Merge pull request #3117 from matrix-org/rav/refactor_have_events

Refactor store.have_events
pull/3125/head
Richard van der Hoff 2018-04-20 10:26:12 +01:00 committed by GitHub
commit bc381d5798
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 55 additions and 27 deletions

View File

@ -394,7 +394,7 @@ class FederationClient(FederationBase):
seen_events = yield self.store.get_events(event_ids, allow_rejected=True) seen_events = yield self.store.get_events(event_ids, allow_rejected=True)
signed_events = seen_events.values() signed_events = seen_events.values()
else: else:
seen_events = yield self.store.have_events(event_ids) seen_events = yield self.store.have_seen_events(event_ids)
signed_events = [] signed_events = []
failed_to_fetch = set() failed_to_fetch = set()

View File

@ -149,10 +149,6 @@ class FederationHandler(BaseHandler):
auth_chain = [] auth_chain = []
have_seen = yield self.store.have_events(
[ev for ev, _ in pdu.prev_events]
)
fetch_state = False fetch_state = False
# Get missing pdus if necessary. # Get missing pdus if necessary.
@ -168,7 +164,7 @@ class FederationHandler(BaseHandler):
) )
prevs = {e_id for e_id, _ in pdu.prev_events} prevs = {e_id for e_id, _ in pdu.prev_events}
seen = set(have_seen.keys()) seen = yield self.store.have_seen_events(prevs)
if min_depth and pdu.depth < min_depth: if min_depth and pdu.depth < min_depth:
# This is so that we don't notify the user about this # This is so that we don't notify the user about this
@ -196,8 +192,7 @@ class FederationHandler(BaseHandler):
# Update the set of things we've seen after trying to # Update the set of things we've seen after trying to
# fetch the missing stuff # fetch the missing stuff
have_seen = yield self.store.have_events(prevs) seen = yield self.store.have_seen_events(prevs)
seen = set(have_seen.iterkeys())
if not prevs - seen: if not prevs - seen:
logger.info( logger.info(
@ -248,8 +243,7 @@ class FederationHandler(BaseHandler):
min_depth (int): Minimum depth of events to return. min_depth (int): Minimum depth of events to return.
""" """
# We recalculate seen, since it may have changed. # We recalculate seen, since it may have changed.
have_seen = yield self.store.have_events(prevs) seen = yield self.store.have_seen_events(prevs)
seen = set(have_seen.keys())
if not prevs - seen: if not prevs - seen:
return return
@ -361,9 +355,7 @@ class FederationHandler(BaseHandler):
if auth_chain: if auth_chain:
event_ids |= {e.event_id for e in auth_chain} event_ids |= {e.event_id for e in auth_chain}
seen_ids = set( seen_ids = yield self.store.have_seen_events(event_ids)
(yield self.store.have_events(event_ids)).keys()
)
if state and auth_chain is not None: if state and auth_chain is not None:
# If we have any state or auth_chain given to us by the replication # If we have any state or auth_chain given to us by the replication
@ -633,7 +625,7 @@ class FederationHandler(BaseHandler):
failed_to_fetch = missing_auth - set(auth_events) failed_to_fetch = missing_auth - set(auth_events)
seen_events = yield self.store.have_events( seen_events = yield self.store.have_seen_events(
set(auth_events.keys()) | set(state_events.keys()) set(auth_events.keys()) | set(state_events.keys())
) )
@ -1736,7 +1728,8 @@ class FederationHandler(BaseHandler):
event_key = None event_key = None
if event_auth_events - current_state: if event_auth_events - current_state:
have_events = yield self.store.have_events( # TODO: can we use store.have_seen_events here instead?
have_events = yield self.store.get_seen_events_with_rejections(
event_auth_events - current_state event_auth_events - current_state
) )
else: else:
@ -1759,12 +1752,12 @@ class FederationHandler(BaseHandler):
origin, event.room_id, event.event_id origin, event.room_id, event.event_id
) )
seen_remotes = yield self.store.have_events( seen_remotes = yield self.store.have_seen_events(
[e.event_id for e in remote_auth_chain] [e.event_id for e in remote_auth_chain]
) )
for e in remote_auth_chain: for e in remote_auth_chain:
if e.event_id in seen_remotes.keys(): if e.event_id in seen_remotes:
continue continue
if e.event_id == event.event_id: if e.event_id == event.event_id:
@ -1791,7 +1784,7 @@ class FederationHandler(BaseHandler):
except AuthError: except AuthError:
pass pass
have_events = yield self.store.have_events( have_events = yield self.store.get_seen_events_with_rejections(
[e_id for e_id, _ in event.auth_events] [e_id for e_id, _ in event.auth_events]
) )
seen_events = set(have_events.keys()) seen_events = set(have_events.keys())
@ -1876,13 +1869,13 @@ class FederationHandler(BaseHandler):
local_auth_chain, local_auth_chain,
) )
seen_remotes = yield self.store.have_events( seen_remotes = yield self.store.have_seen_events(
[e.event_id for e in result["auth_chain"]] [e.event_id for e in result["auth_chain"]]
) )
# 3. Process any remote auth chain events we haven't seen. # 3. Process any remote auth chain events we haven't seen.
for ev in result["auth_chain"]: for ev in result["auth_chain"]:
if ev.event_id in seen_remotes.keys(): if ev.event_id in seen_remotes:
continue continue
if ev.event_id == event.event_id: if ev.event_id == event.event_id:

View File

@ -16,6 +16,7 @@
from collections import OrderedDict, deque, namedtuple from collections import OrderedDict, deque, namedtuple
from functools import wraps from functools import wraps
import itertools
import logging import logging
import simplejson as json import simplejson as json
@ -1320,13 +1321,49 @@ class EventsStore(EventsWorkerStore):
defer.returnValue(set(r["event_id"] for r in rows)) defer.returnValue(set(r["event_id"] for r in rows))
def have_events(self, event_ids): @defer.inlineCallbacks
def have_seen_events(self, event_ids):
"""Given a list of event ids, check if we have already processed them. """Given a list of event ids, check if we have already processed them.
Args:
event_ids (iterable[str]):
Returns: Returns:
dict: Has an entry for each event id we already have seen. Maps to Deferred[set[str]]: The events we have already seen.
the rejected reason string if we rejected the event, else maps to """
None. results = set()
def have_seen_events_txn(txn, chunk):
sql = (
"SELECT event_id FROM events as e WHERE e.event_id IN (%s)"
% (",".join("?" * len(chunk)), )
)
txn.execute(sql, chunk)
for (event_id, ) in txn:
results.add(event_id)
# break the input up into chunks of 100
input_iterator = iter(event_ids)
for chunk in iter(lambda: list(itertools.islice(input_iterator, 100)),
[]):
yield self.runInteraction(
"have_seen_events",
have_seen_events_txn,
chunk,
)
defer.returnValue(results)
def get_seen_events_with_rejections(self, event_ids):
"""Given a list of event ids, check if we rejected them.
Args:
event_ids (list[str])
Returns:
Deferred[dict[str, str|None):
Has an entry for each event id we already have seen. Maps to
the rejected reason string if we rejected the event, else maps
to None.
""" """
if not event_ids: if not event_ids:
return defer.succeed({}) return defer.succeed({})
@ -1348,9 +1385,7 @@ class EventsStore(EventsWorkerStore):
return res return res
return self.runInteraction( return self.runInteraction("get_rejection_reasons", f)
"have_events", f,
)
@defer.inlineCallbacks @defer.inlineCallbacks
def count_daily_messages(self): def count_daily_messages(self):