Refactor Keyring._start_key_lookups
There's an awful lot of deferreds and dictionaries flying around here. The whole thing can be made much simpler and achieve the same effect.pull/5724/head
parent
356ed0438e
commit
c7095be913
|
@ -238,27 +238,9 @@ class Keyring(object):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# create a deferred for each server we're going to look up the keys
|
ctx = LoggingContext.current_context()
|
||||||
# for; we'll resolve them once we have completed our lookups.
|
|
||||||
# These will be passed into wait_for_previous_lookups to block
|
|
||||||
# any other lookups until we have finished.
|
|
||||||
# The deferreds are called with no logcontext.
|
|
||||||
server_to_deferred = {
|
|
||||||
rq.server_name: defer.Deferred() for rq in verify_requests
|
|
||||||
}
|
|
||||||
|
|
||||||
# We want to wait for any previous lookups to complete before
|
# map from server name to a set of outstanding request ids
|
||||||
# proceeding.
|
|
||||||
yield self.wait_for_previous_lookups(server_to_deferred)
|
|
||||||
|
|
||||||
# Actually start fetching keys.
|
|
||||||
self._get_server_verify_keys(verify_requests)
|
|
||||||
|
|
||||||
# When we've finished fetching all the keys for a given server_name,
|
|
||||||
# resolve the deferred passed to `wait_for_previous_lookups` so that
|
|
||||||
# any lookups waiting will proceed.
|
|
||||||
#
|
|
||||||
# map from server name to a set of request ids
|
|
||||||
server_to_request_ids = {}
|
server_to_request_ids = {}
|
||||||
|
|
||||||
for verify_request in verify_requests:
|
for verify_request in verify_requests:
|
||||||
|
@ -266,40 +248,55 @@ class Keyring(object):
|
||||||
request_id = id(verify_request)
|
request_id = id(verify_request)
|
||||||
server_to_request_ids.setdefault(server_name, set()).add(request_id)
|
server_to_request_ids.setdefault(server_name, set()).add(request_id)
|
||||||
|
|
||||||
def remove_deferreds(res, verify_request):
|
# Wait for any previous lookups to complete before proceeding.
|
||||||
|
yield self.wait_for_previous_lookups(server_to_request_ids.keys())
|
||||||
|
|
||||||
|
# take out a lock on each of the servers by sticking a Deferred in
|
||||||
|
# key_downloads
|
||||||
|
for server_name in server_to_request_ids.keys():
|
||||||
|
self.key_downloads[server_name] = defer.Deferred()
|
||||||
|
logger.debug("Got key lookup lock on %s", server_name)
|
||||||
|
|
||||||
|
# When we've finished fetching all the keys for a given server_name,
|
||||||
|
# drop the lock by resolving the deferred in key_downloads.
|
||||||
|
def lookup_done(res, verify_request):
|
||||||
server_name = verify_request.server_name
|
server_name = verify_request.server_name
|
||||||
request_id = id(verify_request)
|
server_requests = server_to_request_ids[server_name]
|
||||||
server_to_request_ids[server_name].discard(request_id)
|
server_requests.remove(id(verify_request))
|
||||||
if not server_to_request_ids[server_name]:
|
|
||||||
d = server_to_deferred.pop(server_name, None)
|
# if there are no more requests for this server, we can drop the lock.
|
||||||
if d:
|
if not server_requests:
|
||||||
d.callback(None)
|
with PreserveLoggingContext(ctx):
|
||||||
|
logger.debug("Releasing key lookup lock on %s", server_name)
|
||||||
|
|
||||||
|
d = self.key_downloads.pop(server_name)
|
||||||
|
d.callback(None)
|
||||||
return res
|
return res
|
||||||
|
|
||||||
for verify_request in verify_requests:
|
for verify_request in verify_requests:
|
||||||
verify_request.key_ready.addBoth(remove_deferreds, verify_request)
|
verify_request.key_ready.addBoth(lookup_done, verify_request)
|
||||||
|
|
||||||
|
# Actually start fetching keys.
|
||||||
|
self._get_server_verify_keys(verify_requests)
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.exception("Error starting key lookups")
|
logger.exception("Error starting key lookups")
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def wait_for_previous_lookups(self, server_to_deferred):
|
def wait_for_previous_lookups(self, server_names):
|
||||||
"""Waits for any previous key lookups for the given servers to finish.
|
"""Waits for any previous key lookups for the given servers to finish.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
server_to_deferred (dict[str, Deferred]): server_name to deferred which gets
|
server_names (Iterable[str]): list of servers which we want to look up
|
||||||
resolved once we've finished looking up keys for that server.
|
|
||||||
The Deferreds should be regular twisted ones which call their
|
|
||||||
callbacks with no logcontext.
|
|
||||||
|
|
||||||
Returns: a Deferred which resolves once all key lookups for the given
|
Returns:
|
||||||
servers have completed. Follows the synapse rules of logcontext
|
Deferred[None]: resolves once all key lookups for the given servers have
|
||||||
preservation.
|
completed. Follows the synapse rules of logcontext preservation.
|
||||||
"""
|
"""
|
||||||
loop_count = 1
|
loop_count = 1
|
||||||
while True:
|
while True:
|
||||||
wait_on = [
|
wait_on = [
|
||||||
(server_name, self.key_downloads[server_name])
|
(server_name, self.key_downloads[server_name])
|
||||||
for server_name in server_to_deferred.keys()
|
for server_name in server_names
|
||||||
if server_name in self.key_downloads
|
if server_name in self.key_downloads
|
||||||
]
|
]
|
||||||
if not wait_on:
|
if not wait_on:
|
||||||
|
@ -314,19 +311,6 @@ class Keyring(object):
|
||||||
|
|
||||||
loop_count += 1
|
loop_count += 1
|
||||||
|
|
||||||
ctx = LoggingContext.current_context()
|
|
||||||
|
|
||||||
def rm(r, server_name_):
|
|
||||||
with PreserveLoggingContext(ctx):
|
|
||||||
logger.debug("Releasing key lookup lock on %s", server_name_)
|
|
||||||
self.key_downloads.pop(server_name_, None)
|
|
||||||
return r
|
|
||||||
|
|
||||||
for server_name, deferred in server_to_deferred.items():
|
|
||||||
logger.debug("Got key lookup lock on %s", server_name)
|
|
||||||
self.key_downloads[server_name] = deferred
|
|
||||||
deferred.addBoth(rm, server_name)
|
|
||||||
|
|
||||||
def _get_server_verify_keys(self, verify_requests):
|
def _get_server_verify_keys(self, verify_requests):
|
||||||
"""Tries to find at least one key for each verify request
|
"""Tries to find at least one key for each verify request
|
||||||
|
|
||||||
|
|
|
@ -86,35 +86,6 @@ class KeyringTestCase(unittest.HomeserverTestCase):
|
||||||
getattr(LoggingContext.current_context(), "request", None), expected
|
getattr(LoggingContext.current_context(), "request", None), expected
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_wait_for_previous_lookups(self):
|
|
||||||
kr = keyring.Keyring(self.hs)
|
|
||||||
|
|
||||||
lookup_1_deferred = defer.Deferred()
|
|
||||||
lookup_2_deferred = defer.Deferred()
|
|
||||||
|
|
||||||
# we run the lookup in a logcontext so that the patched inlineCallbacks can check
|
|
||||||
# it is doing the right thing with logcontexts.
|
|
||||||
wait_1_deferred = run_in_context(
|
|
||||||
kr.wait_for_previous_lookups, {"server1": lookup_1_deferred}
|
|
||||||
)
|
|
||||||
|
|
||||||
# there were no previous lookups, so the deferred should be ready
|
|
||||||
self.successResultOf(wait_1_deferred)
|
|
||||||
|
|
||||||
# set off another wait. It should block because the first lookup
|
|
||||||
# hasn't yet completed.
|
|
||||||
wait_2_deferred = run_in_context(
|
|
||||||
kr.wait_for_previous_lookups, {"server1": lookup_2_deferred}
|
|
||||||
)
|
|
||||||
|
|
||||||
self.assertFalse(wait_2_deferred.called)
|
|
||||||
|
|
||||||
# let the first lookup complete (in the sentinel context)
|
|
||||||
lookup_1_deferred.callback(None)
|
|
||||||
|
|
||||||
# now the second wait should complete.
|
|
||||||
self.successResultOf(wait_2_deferred)
|
|
||||||
|
|
||||||
def test_verify_json_objects_for_server_awaits_previous_requests(self):
|
def test_verify_json_objects_for_server_awaits_previous_requests(self):
|
||||||
key1 = signedjson.key.generate_signing_key(1)
|
key1 = signedjson.key.generate_signing_key(1)
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue