Verify state and auth_chain in the same batch
parent
74f7b44955
commit
0f2ac80305
|
@ -32,7 +32,8 @@ logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
class FederationBase(object):
|
class FederationBase(object):
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def _check_sigs_and_hash_and_fetch(self, origin, pdus, outlier=False):
|
def _check_sigs_and_hash_and_fetch(self, origin, pdus, outlier=False,
|
||||||
|
include_none=False):
|
||||||
"""Takes a list of PDUs and checks the signatures and hashs of each
|
"""Takes a list of PDUs and checks the signatures and hashs of each
|
||||||
one. If a PDU fails its signature check then we check if we have it in
|
one. If a PDU fails its signature check then we check if we have it in
|
||||||
the database and if not then request if from the originating server of
|
the database and if not then request if from the originating server of
|
||||||
|
@ -56,51 +57,60 @@ class FederationBase(object):
|
||||||
deferreds = self._check_sigs_and_hashes(pdus)
|
deferreds = self._check_sigs_and_hashes(pdus)
|
||||||
|
|
||||||
def callback(pdu):
|
def callback(pdu):
|
||||||
signed_pdus.append(pdu)
|
return pdu
|
||||||
|
|
||||||
def errback(failure, pdu):
|
def errback(failure, pdu):
|
||||||
failure.trap(SynapseError)
|
failure.trap(SynapseError)
|
||||||
|
return None
|
||||||
|
|
||||||
# Check local db.
|
def try_local_db(res, pdu):
|
||||||
new_pdu = yield self.store.get_event(
|
if not res:
|
||||||
pdu.event_id,
|
# Check local db.
|
||||||
allow_rejected=True,
|
return self.store.get_event(
|
||||||
allow_none=True,
|
pdu.event_id,
|
||||||
)
|
allow_rejected=True,
|
||||||
if new_pdu:
|
allow_none=True,
|
||||||
signed_pdus.append(new_pdu)
|
)
|
||||||
return
|
return res
|
||||||
|
|
||||||
# Check pdu.origin
|
def try_remote(res, pdu):
|
||||||
if pdu.origin != origin:
|
if not res and pdu.origin != origin:
|
||||||
try:
|
return self.get_pdu(
|
||||||
new_pdu = yield self.get_pdu(
|
destinations=[pdu.origin],
|
||||||
destinations=[pdu.origin],
|
event_id=pdu.event_id,
|
||||||
event_id=pdu.event_id,
|
outlier=outlier,
|
||||||
outlier=outlier,
|
timeout=10000,
|
||||||
timeout=10000,
|
).addErrback(lambda e: None)
|
||||||
)
|
return res
|
||||||
|
|
||||||
if new_pdu:
|
def warn(res, pdu):
|
||||||
signed_pdus.append(new_pdu)
|
if not res:
|
||||||
return
|
logger.warn(
|
||||||
except:
|
"Failed to find copy of %s with valid signature",
|
||||||
pass
|
pdu.event_id,
|
||||||
|
)
|
||||||
logger.warn(
|
return res
|
||||||
"Failed to find copy of %s with valid signature",
|
|
||||||
pdu.event_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
for pdu, deferred in zip(pdus, deferreds):
|
for pdu, deferred in zip(pdus, deferreds):
|
||||||
deferred.addCallbacks(callback, errback, errbackArgs=[pdu])
|
deferred.addCallbacks(
|
||||||
|
callback, errback, errbackArgs=[pdu]
|
||||||
|
).addCallback(
|
||||||
|
try_local_db, pdu
|
||||||
|
).addCallback(
|
||||||
|
try_remote, pdu
|
||||||
|
).addCallback(
|
||||||
|
warn, pdu
|
||||||
|
)
|
||||||
|
|
||||||
yield defer.gatherResults(
|
valid_pdus = yield defer.gatherResults(
|
||||||
deferreds,
|
deferreds,
|
||||||
consumeErrors=True
|
consumeErrors=True
|
||||||
).addErrback(unwrapFirstError)
|
).addErrback(unwrapFirstError)
|
||||||
|
|
||||||
defer.returnValue(signed_pdus)
|
if include_none:
|
||||||
|
defer.returnValue(valid_pdus)
|
||||||
|
else:
|
||||||
|
defer.returnValue([p for p in valid_pdus if p])
|
||||||
|
|
||||||
def _check_sigs_and_hash(self, pdu):
|
def _check_sigs_and_hash(self, pdu):
|
||||||
return self._check_sigs_and_hashes([pdu])[0]
|
return self._check_sigs_and_hashes([pdu])[0]
|
||||||
|
|
|
@ -380,17 +380,14 @@ class FederationClient(FederationBase):
|
||||||
for p in content.get("auth_chain", [])
|
for p in content.get("auth_chain", [])
|
||||||
]
|
]
|
||||||
|
|
||||||
signed_state, signed_auth = yield defer.gatherResults(
|
valid_pdus = yield self._check_sigs_and_hash_and_fetch(
|
||||||
[
|
destination, state + auth_chain,
|
||||||
self._check_sigs_and_hash_and_fetch(
|
outlier=True,
|
||||||
destination, state, outlier=True
|
include_none=True,
|
||||||
),
|
)
|
||||||
self._check_sigs_and_hash_and_fetch(
|
|
||||||
destination, auth_chain, outlier=True
|
signed_state = [p for p in valid_pdus[:len(state)] if p]
|
||||||
)
|
signed_auth = [p for p in valid_pdus[len(state):] if p]
|
||||||
],
|
|
||||||
consumeErrors=True
|
|
||||||
).addErrback(unwrapFirstError)
|
|
||||||
|
|
||||||
auth_chain.sort(key=lambda e: e.depth)
|
auth_chain.sort(key=lambda e: e.depth)
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue