Verify state and auth_chain in the same batch

erikj/persist_event_perf
Erik Johnston 2015-06-24 14:51:10 +01:00
parent 74f7b44955
commit 0f2ac80305
2 changed files with 51 additions and 44 deletions

View File

@ -32,7 +32,8 @@ logger = logging.getLogger(__name__)
class FederationBase(object): class FederationBase(object):
@defer.inlineCallbacks @defer.inlineCallbacks
def _check_sigs_and_hash_and_fetch(self, origin, pdus, outlier=False): def _check_sigs_and_hash_and_fetch(self, origin, pdus, outlier=False,
include_none=False):
"""Takes a list of PDUs and checks the signatures and hashs of each """Takes a list of PDUs and checks the signatures and hashs of each
one. If a PDU fails its signature check then we check if we have it in one. If a PDU fails its signature check then we check if we have it in
the database and if not then request if from the originating server of the database and if not then request if from the originating server of
@ -56,51 +57,60 @@ class FederationBase(object):
deferreds = self._check_sigs_and_hashes(pdus) deferreds = self._check_sigs_and_hashes(pdus)
def callback(pdu): def callback(pdu):
signed_pdus.append(pdu) return pdu
def errback(failure, pdu): def errback(failure, pdu):
failure.trap(SynapseError) failure.trap(SynapseError)
return None
def try_local_db(res, pdu):
if not res:
# Check local db. # Check local db.
new_pdu = yield self.store.get_event( return self.store.get_event(
pdu.event_id, pdu.event_id,
allow_rejected=True, allow_rejected=True,
allow_none=True, allow_none=True,
) )
if new_pdu: return res
signed_pdus.append(new_pdu)
return
# Check pdu.origin def try_remote(res, pdu):
if pdu.origin != origin: if not res and pdu.origin != origin:
try: return self.get_pdu(
new_pdu = yield self.get_pdu(
destinations=[pdu.origin], destinations=[pdu.origin],
event_id=pdu.event_id, event_id=pdu.event_id,
outlier=outlier, outlier=outlier,
timeout=10000, timeout=10000,
) ).addErrback(lambda e: None)
return res
if new_pdu:
signed_pdus.append(new_pdu)
return
except:
pass
def warn(res, pdu):
if not res:
logger.warn( logger.warn(
"Failed to find copy of %s with valid signature", "Failed to find copy of %s with valid signature",
pdu.event_id, pdu.event_id,
) )
return res
for pdu, deferred in zip(pdus, deferreds): for pdu, deferred in zip(pdus, deferreds):
deferred.addCallbacks(callback, errback, errbackArgs=[pdu]) deferred.addCallbacks(
callback, errback, errbackArgs=[pdu]
).addCallback(
try_local_db, pdu
).addCallback(
try_remote, pdu
).addCallback(
warn, pdu
)
yield defer.gatherResults( valid_pdus = yield defer.gatherResults(
deferreds, deferreds,
consumeErrors=True consumeErrors=True
).addErrback(unwrapFirstError) ).addErrback(unwrapFirstError)
defer.returnValue(signed_pdus) if include_none:
defer.returnValue(valid_pdus)
else:
defer.returnValue([p for p in valid_pdus if p])
def _check_sigs_and_hash(self, pdu): def _check_sigs_and_hash(self, pdu):
return self._check_sigs_and_hashes([pdu])[0] return self._check_sigs_and_hashes([pdu])[0]

View File

@ -380,17 +380,14 @@ class FederationClient(FederationBase):
for p in content.get("auth_chain", []) for p in content.get("auth_chain", [])
] ]
signed_state, signed_auth = yield defer.gatherResults( valid_pdus = yield self._check_sigs_and_hash_and_fetch(
[ destination, state + auth_chain,
self._check_sigs_and_hash_and_fetch( outlier=True,
destination, state, outlier=True include_none=True,
),
self._check_sigs_and_hash_and_fetch(
destination, auth_chain, outlier=True
) )
],
consumeErrors=True signed_state = [p for p in valid_pdus[:len(state)] if p]
).addErrback(unwrapFirstError) signed_auth = [p for p in valid_pdus[len(state):] if p]
auth_chain.sort(key=lambda e: e.depth) auth_chain.sort(key=lambda e: e.depth)