Only store event_auth for state events

pull/2247/head
Erik Johnston 2017-05-24 14:22:41 +01:00
parent 58c4720293
commit c049472b8a
3 changed files with 44 additions and 12 deletions

View File

@ -832,7 +832,11 @@ class FederationHandler(BaseHandler):
@defer.inlineCallbacks
def on_event_auth(self, event_id):
auth = yield self.store.get_auth_chain([event_id])
event = yield self.store.get_event(event_id)
auth = yield self.store.get_auth_chain(
[auth_id for auth_id, _ in event.auth_events],
include_given=True
)
for event in auth:
event.signatures.update(
@ -1047,9 +1051,7 @@ class FederationHandler(BaseHandler):
yield user_joined_room(self.distributor, user, event.room_id)
state_ids = context.prev_state_ids.values()
auth_chain = yield self.store.get_auth_chain(set(
[event.event_id] + state_ids
))
auth_chain = yield self.store.get_auth_chain(state_ids)
state = yield self.store.get_events(context.prev_state_ids.values())
@ -1598,7 +1600,11 @@ class FederationHandler(BaseHandler):
pass
# Now get the current auth_chain for the event.
local_auth_chain = yield self.store.get_auth_chain([event_id])
event = yield self.store.get_event(event_id)
local_auth_chain = yield self.store.get_auth_chain(
[auth_id for auth_id, _ in event.auth_events],
include_given=True
)
# TODO: Check if we would now reject event_id. If so we need to tell
# everyone.
@ -1791,7 +1797,9 @@ class FederationHandler(BaseHandler):
auth_ids = yield self.auth.compute_auth_events(
event, context.prev_state_ids
)
local_auth_chain = yield self.store.get_auth_chain(auth_ids)
local_auth_chain = yield self.store.get_auth_chain(
auth_ids, include_given=True
)
try:
# 2. Get remote difference.

View File

@ -44,18 +44,41 @@ class EventFederationStore(SQLBaseStore):
self._delete_old_forward_extrem_cache, 60 * 60 * 1000
)
def get_auth_chain(self, event_ids):
return self.get_auth_chain_ids(event_ids).addCallback(self._get_events)
def get_auth_chain(self, event_ids, include_given=False):
"""Get auth events for given event_ids. The events *must* be state events.
def get_auth_chain_ids(self, event_ids):
Args:
event_ids (list): state events
include_given (bool): include the given events in result
Returns:
list of events
"""
return self.get_auth_chain_ids(
event_ids, include_given=include_given,
).addCallback(self._get_events)
def get_auth_chain_ids(self, event_ids, include_given=False):
"""Get auth events for given event_ids. The events *must* be state events.
Args:
event_ids (list): state events
include_given (bool): include the given events in result
Returns:
list of event_ids
"""
return self.runInteraction(
"get_auth_chain_ids",
self._get_auth_chain_ids_txn,
event_ids
event_ids, include_given
)
def _get_auth_chain_ids_txn(self, txn, event_ids):
results = set()
def _get_auth_chain_ids_txn(self, txn, event_ids, include_given):
if include_given:
results = set(event_ids)
else:
results = set()
base_sql = (
"SELECT auth_id FROM event_auth WHERE event_id IN (%s)"

View File

@ -1120,6 +1120,7 @@ class EventsStore(SQLBaseStore):
}
for event, _ in events_and_contexts
for auth_id, _ in event.auth_events
if event.is_state()
],
)