Wait for previous attempts at fetching keys for a given server before trying to fetch more

pull/194/head
Erik Johnston 2015-06-26 11:25:00 +01:00
parent b5f55a1d85
commit f0dd568e16
1 changed files with 68 additions and 15 deletions

View File

@ -27,6 +27,8 @@ from synapse.api.errors import SynapseError, Codes
from synapse.util.retryutils import get_retry_limiter
from synapse.util import unwrapFirstError
from synapse.util.async import ObservableDeferred
from OpenSSL import crypto
from collections import namedtuple
@ -88,6 +90,8 @@ class Keyring(object):
"Not signed with a supported algorithm",
Codes.UNAUTHORIZED,
))
else:
deferreds[group_id] = defer.Deferred()
group = KeyGroup(server_name, group_id, key_ids)
@ -133,10 +137,41 @@ class Keyring(object):
Codes.UNAUTHORIZED,
)
deferreds.update(self.get_server_verify_keys(
group_id_to_group
))
server_to_deferred = {
server_name: defer.Deferred()
for server_name, _ in server_and_json
}
# We want to wait for any previous lookups to complete before
# proceeding.
wait_on_deferred = self.wait_for_previous_lookups(
[server_name for server_name, _ in server_and_json],
server_to_deferred,
)
# Actually start fetching keys.
wait_on_deferred.addBoth(
lambda _: self.get_server_verify_keys(group_id_to_group, deferreds)
)
# When we've finished fetching all the keys for a given server_name,
# resolve the deferred passed to `wait_for_previous_lookups` so that
# any lookups waiting will proceed.
server_to_gids = {}
def remove_deferreds(res, server_name, group_id):
server_to_gids[server_name].discard(group_id)
if not server_to_gids[server_name]:
server_to_deferred.pop(server_name).callback(None)
return res
for g_id, deferred in deferreds.items():
server_name = group_id_to_group[g_id].server_name
server_to_gids.setdefault(server_name, set()).add(g_id)
deferred.addBoth(remove_deferreds, server_name, g_id)
# Pass those keys to handle_key_deferred so that the json object
# signatures can be verified
return [
handle_key_deferred(
group_id_to_group[g_id],
@ -145,7 +180,30 @@ class Keyring(object):
for g_id in group_ids
]
def get_server_verify_keys(self, group_id_to_group):
@defer.inlineCallbacks
def wait_for_previous_lookups(self, server_names, server_to_deferred):
"""Waits for any previous key lookups for the given servers to finish.
Args:
server_names (list): list of server_names we want to lookup
server_to_deferred (dict): server_name to deferred which gets
resolved once we've finished looking up keys for that server
"""
while True:
wait_on = [
self.key_downloads[server_name]
for server_name in server_names
if server_name in self.key_downloads
]
if wait_on:
yield defer.DeferredList(wait_on)
else:
break
for server_name, deferred in server_to_deferred:
self.key_downloads[server_name] = ObservableDeferred(deferred)
def get_server_verify_keys(self, group_id_to_group, group_id_to_deferred):
"""Takes a dict of KeyGroups and tries to find at least one key for
each group.
"""
@ -157,11 +215,6 @@ class Keyring(object):
self.get_keys_from_server, # Then try directly
)
group_deferreds = {
group_id: defer.Deferred()
for group_id in group_id_to_group
}
@defer.inlineCallbacks
def do_iterations():
merged_results = {}
@ -182,7 +235,7 @@ class Keyring(object):
for group in group_id_to_group.values():
for key_id in group.key_ids:
if key_id in merged_results[group.server_name]:
group_deferreds.pop(group.group_id).callback((
group_id_to_deferred[group.group_id].callback((
group.group_id,
group.server_name,
key_id,
@ -205,7 +258,7 @@ class Keyring(object):
}
for group in missing_groups.values():
group_deferreds.pop(group.group_id).errback(SynapseError(
group_id_to_deferred[group.group_id].errback(SynapseError(
401,
"No key for %s with id %s" % (
group.server_name, group.key_ids,
@ -214,13 +267,13 @@ class Keyring(object):
))
def on_err(err):
for deferred in group_deferreds.values():
deferred.errback(err)
group_deferreds.clear()
for deferred in group_id_to_deferred.values():
if not deferred.called:
deferred.errback(err)
do_iterations().addErrback(on_err)
return group_deferreds
return group_id_to_deferred
@defer.inlineCallbacks
def get_keys_from_store(self, server_name_and_key_ids):