Refactor Keyring._start_key_lookups

There's an awful lot of deferreds and dictionaries flying around here. The
whole thing can be made much simpler and achieve the same effect.
pull/5724/head
Richard van der Hoff 2019-07-19 17:49:19 +01:00
parent 356ed0438e
commit c7095be913
2 changed files with 34 additions and 79 deletions

View File

@ -238,27 +238,9 @@ class Keyring(object):
""" """
try: try:
# create a deferred for each server we're going to look up the keys ctx = LoggingContext.current_context()
# for; we'll resolve them once we have completed our lookups.
# These will be passed into wait_for_previous_lookups to block
# any other lookups until we have finished.
# The deferreds are called with no logcontext.
server_to_deferred = {
rq.server_name: defer.Deferred() for rq in verify_requests
}
# We want to wait for any previous lookups to complete before # map from server name to a set of outstanding request ids
# proceeding.
yield self.wait_for_previous_lookups(server_to_deferred)
# Actually start fetching keys.
self._get_server_verify_keys(verify_requests)
# When we've finished fetching all the keys for a given server_name,
# resolve the deferred passed to `wait_for_previous_lookups` so that
# any lookups waiting will proceed.
#
# map from server name to a set of request ids
server_to_request_ids = {} server_to_request_ids = {}
for verify_request in verify_requests: for verify_request in verify_requests:
@ -266,40 +248,55 @@ class Keyring(object):
request_id = id(verify_request) request_id = id(verify_request)
server_to_request_ids.setdefault(server_name, set()).add(request_id) server_to_request_ids.setdefault(server_name, set()).add(request_id)
def remove_deferreds(res, verify_request): # Wait for any previous lookups to complete before proceeding.
yield self.wait_for_previous_lookups(server_to_request_ids.keys())
# take out a lock on each of the servers by sticking a Deferred in
# key_downloads
for server_name in server_to_request_ids.keys():
self.key_downloads[server_name] = defer.Deferred()
logger.debug("Got key lookup lock on %s", server_name)
# When we've finished fetching all the keys for a given server_name,
# drop the lock by resolving the deferred in key_downloads.
def lookup_done(res, verify_request):
server_name = verify_request.server_name server_name = verify_request.server_name
request_id = id(verify_request) server_requests = server_to_request_ids[server_name]
server_to_request_ids[server_name].discard(request_id) server_requests.remove(id(verify_request))
if not server_to_request_ids[server_name]:
d = server_to_deferred.pop(server_name, None) # if there are no more requests for this server, we can drop the lock.
if d: if not server_requests:
d.callback(None) with PreserveLoggingContext(ctx):
logger.debug("Releasing key lookup lock on %s", server_name)
d = self.key_downloads.pop(server_name)
d.callback(None)
return res return res
for verify_request in verify_requests: for verify_request in verify_requests:
verify_request.key_ready.addBoth(remove_deferreds, verify_request) verify_request.key_ready.addBoth(lookup_done, verify_request)
# Actually start fetching keys.
self._get_server_verify_keys(verify_requests)
except Exception: except Exception:
logger.exception("Error starting key lookups") logger.exception("Error starting key lookups")
@defer.inlineCallbacks @defer.inlineCallbacks
def wait_for_previous_lookups(self, server_to_deferred): def wait_for_previous_lookups(self, server_names):
"""Waits for any previous key lookups for the given servers to finish. """Waits for any previous key lookups for the given servers to finish.
Args: Args:
server_to_deferred (dict[str, Deferred]): server_name to deferred which gets server_names (Iterable[str]): list of servers which we want to look up
resolved once we've finished looking up keys for that server.
The Deferreds should be regular twisted ones which call their
callbacks with no logcontext.
Returns: a Deferred which resolves once all key lookups for the given Returns:
servers have completed. Follows the synapse rules of logcontext Deferred[None]: resolves once all key lookups for the given servers have
preservation. completed. Follows the synapse rules of logcontext preservation.
""" """
loop_count = 1 loop_count = 1
while True: while True:
wait_on = [ wait_on = [
(server_name, self.key_downloads[server_name]) (server_name, self.key_downloads[server_name])
for server_name in server_to_deferred.keys() for server_name in server_names
if server_name in self.key_downloads if server_name in self.key_downloads
] ]
if not wait_on: if not wait_on:
@ -314,19 +311,6 @@ class Keyring(object):
loop_count += 1 loop_count += 1
ctx = LoggingContext.current_context()
def rm(r, server_name_):
with PreserveLoggingContext(ctx):
logger.debug("Releasing key lookup lock on %s", server_name_)
self.key_downloads.pop(server_name_, None)
return r
for server_name, deferred in server_to_deferred.items():
logger.debug("Got key lookup lock on %s", server_name)
self.key_downloads[server_name] = deferred
deferred.addBoth(rm, server_name)
def _get_server_verify_keys(self, verify_requests): def _get_server_verify_keys(self, verify_requests):
"""Tries to find at least one key for each verify request """Tries to find at least one key for each verify request

View File

@ -86,35 +86,6 @@ class KeyringTestCase(unittest.HomeserverTestCase):
getattr(LoggingContext.current_context(), "request", None), expected getattr(LoggingContext.current_context(), "request", None), expected
) )
def test_wait_for_previous_lookups(self):
kr = keyring.Keyring(self.hs)
lookup_1_deferred = defer.Deferred()
lookup_2_deferred = defer.Deferred()
# we run the lookup in a logcontext so that the patched inlineCallbacks can check
# it is doing the right thing with logcontexts.
wait_1_deferred = run_in_context(
kr.wait_for_previous_lookups, {"server1": lookup_1_deferred}
)
# there were no previous lookups, so the deferred should be ready
self.successResultOf(wait_1_deferred)
# set off another wait. It should block because the first lookup
# hasn't yet completed.
wait_2_deferred = run_in_context(
kr.wait_for_previous_lookups, {"server1": lookup_2_deferred}
)
self.assertFalse(wait_2_deferred.called)
# let the first lookup complete (in the sentinel context)
lookup_1_deferred.callback(None)
# now the second wait should complete.
self.successResultOf(wait_2_deferred)
def test_verify_json_objects_for_server_awaits_previous_requests(self): def test_verify_json_objects_for_server_awaits_previous_requests(self):
key1 = signedjson.key.generate_signing_key(1) key1 = signedjson.key.generate_signing_key(1)