Fix stack overflow in Keyring (#5724)
* Refactor Keyring._start_key_lookups There's an awful lot of deferreds and dictionaries flying around here. The whole thing can be made much simpler and achieve the same effect. * Add a delay to key lookup lock release to fix stack overflow A tactical call_later here should fix #5723 * changelogpull/5733/head
						commit
						0cb72812f9
					
				|  | @ -0,0 +1 @@ | |||
| Fix stack overflow in server key lookup code. | ||||
|  | @ -238,27 +238,9 @@ class Keyring(object): | |||
|         """ | ||||
| 
 | ||||
|         try: | ||||
|             # create a deferred for each server we're going to look up the keys | ||||
|             # for; we'll resolve them once we have completed our lookups. | ||||
|             # These will be passed into wait_for_previous_lookups to block | ||||
|             # any other lookups until we have finished. | ||||
|             # The deferreds are called with no logcontext. | ||||
|             server_to_deferred = { | ||||
|                 rq.server_name: defer.Deferred() for rq in verify_requests | ||||
|             } | ||||
|             ctx = LoggingContext.current_context() | ||||
| 
 | ||||
|             # We want to wait for any previous lookups to complete before | ||||
|             # proceeding. | ||||
|             yield self.wait_for_previous_lookups(server_to_deferred) | ||||
| 
 | ||||
|             # Actually start fetching keys. | ||||
|             self._get_server_verify_keys(verify_requests) | ||||
| 
 | ||||
|             # When we've finished fetching all the keys for a given server_name, | ||||
|             # resolve the deferred passed to `wait_for_previous_lookups` so that | ||||
|             # any lookups waiting will proceed. | ||||
|             # | ||||
|             # map from server name to a set of request ids | ||||
|             # map from server name to a set of outstanding request ids | ||||
|             server_to_request_ids = {} | ||||
| 
 | ||||
|             for verify_request in verify_requests: | ||||
|  | @ -266,40 +248,61 @@ class Keyring(object): | |||
|                 request_id = id(verify_request) | ||||
|                 server_to_request_ids.setdefault(server_name, set()).add(request_id) | ||||
| 
 | ||||
|             def remove_deferreds(res, verify_request): | ||||
|             # Wait for any previous lookups to complete before proceeding. | ||||
|             yield self.wait_for_previous_lookups(server_to_request_ids.keys()) | ||||
| 
 | ||||
|             # take out a lock on each of the servers by sticking a Deferred in | ||||
|             # key_downloads | ||||
|             for server_name in server_to_request_ids.keys(): | ||||
|                 self.key_downloads[server_name] = defer.Deferred() | ||||
|                 logger.debug("Got key lookup lock on %s", server_name) | ||||
| 
 | ||||
|             # When we've finished fetching all the keys for a given server_name, | ||||
|             # drop the lock by resolving the deferred in key_downloads. | ||||
|             def drop_server_lock(server_name): | ||||
|                 d = self.key_downloads.pop(server_name) | ||||
|                 d.callback(None) | ||||
| 
 | ||||
|             def lookup_done(res, verify_request): | ||||
|                 server_name = verify_request.server_name | ||||
|                 request_id = id(verify_request) | ||||
|                 server_to_request_ids[server_name].discard(request_id) | ||||
|                 if not server_to_request_ids[server_name]: | ||||
|                     d = server_to_deferred.pop(server_name, None) | ||||
|                     if d: | ||||
|                         d.callback(None) | ||||
|                 server_requests = server_to_request_ids[server_name] | ||||
|                 server_requests.remove(id(verify_request)) | ||||
| 
 | ||||
|                 # if there are no more requests for this server, we can drop the lock. | ||||
|                 if not server_requests: | ||||
|                     with PreserveLoggingContext(ctx): | ||||
|                         logger.debug("Releasing key lookup lock on %s", server_name) | ||||
| 
 | ||||
|                     # ... but not immediately, as that can cause stack explosions if | ||||
|                     # we get a long queue of lookups. | ||||
|                     self.clock.call_later(0, drop_server_lock, server_name) | ||||
| 
 | ||||
|                 return res | ||||
| 
 | ||||
|             for verify_request in verify_requests: | ||||
|                 verify_request.key_ready.addBoth(remove_deferreds, verify_request) | ||||
|                 verify_request.key_ready.addBoth(lookup_done, verify_request) | ||||
| 
 | ||||
|             # Actually start fetching keys. | ||||
|             self._get_server_verify_keys(verify_requests) | ||||
|         except Exception: | ||||
|             logger.exception("Error starting key lookups") | ||||
| 
 | ||||
|     @defer.inlineCallbacks | ||||
|     def wait_for_previous_lookups(self, server_to_deferred): | ||||
|     def wait_for_previous_lookups(self, server_names): | ||||
|         """Waits for any previous key lookups for the given servers to finish. | ||||
| 
 | ||||
|         Args: | ||||
|             server_to_deferred (dict[str, Deferred]): server_name to deferred which gets | ||||
|                 resolved once we've finished looking up keys for that server. | ||||
|                 The Deferreds should be regular twisted ones which call their | ||||
|                 callbacks with no logcontext. | ||||
|             server_names (Iterable[str]): list of servers which we want to look up | ||||
| 
 | ||||
|         Returns: a Deferred which resolves once all key lookups for the given | ||||
|             servers have completed. Follows the synapse rules of logcontext | ||||
|             preservation. | ||||
|         Returns: | ||||
|             Deferred[None]: resolves once all key lookups for the given servers have | ||||
|                 completed. Follows the synapse rules of logcontext preservation. | ||||
|         """ | ||||
|         loop_count = 1 | ||||
|         while True: | ||||
|             wait_on = [ | ||||
|                 (server_name, self.key_downloads[server_name]) | ||||
|                 for server_name in server_to_deferred.keys() | ||||
|                 for server_name in server_names | ||||
|                 if server_name in self.key_downloads | ||||
|             ] | ||||
|             if not wait_on: | ||||
|  | @ -314,19 +317,6 @@ class Keyring(object): | |||
| 
 | ||||
|             loop_count += 1 | ||||
| 
 | ||||
|         ctx = LoggingContext.current_context() | ||||
| 
 | ||||
|         def rm(r, server_name_): | ||||
|             with PreserveLoggingContext(ctx): | ||||
|                 logger.debug("Releasing key lookup lock on %s", server_name_) | ||||
|                 self.key_downloads.pop(server_name_, None) | ||||
|             return r | ||||
| 
 | ||||
|         for server_name, deferred in server_to_deferred.items(): | ||||
|             logger.debug("Got key lookup lock on %s", server_name) | ||||
|             self.key_downloads[server_name] = deferred | ||||
|             deferred.addBoth(rm, server_name) | ||||
| 
 | ||||
|     def _get_server_verify_keys(self, verify_requests): | ||||
|         """Tries to find at least one key for each verify request | ||||
| 
 | ||||
|  |  | |||
|  | @ -86,35 +86,6 @@ class KeyringTestCase(unittest.HomeserverTestCase): | |||
|             getattr(LoggingContext.current_context(), "request", None), expected | ||||
|         ) | ||||
| 
 | ||||
|     def test_wait_for_previous_lookups(self): | ||||
|         kr = keyring.Keyring(self.hs) | ||||
| 
 | ||||
|         lookup_1_deferred = defer.Deferred() | ||||
|         lookup_2_deferred = defer.Deferred() | ||||
| 
 | ||||
|         # we run the lookup in a logcontext so that the patched inlineCallbacks can check | ||||
|         # it is doing the right thing with logcontexts. | ||||
|         wait_1_deferred = run_in_context( | ||||
|             kr.wait_for_previous_lookups, {"server1": lookup_1_deferred} | ||||
|         ) | ||||
| 
 | ||||
|         # there were no previous lookups, so the deferred should be ready | ||||
|         self.successResultOf(wait_1_deferred) | ||||
| 
 | ||||
|         # set off another wait. It should block because the first lookup | ||||
|         # hasn't yet completed. | ||||
|         wait_2_deferred = run_in_context( | ||||
|             kr.wait_for_previous_lookups, {"server1": lookup_2_deferred} | ||||
|         ) | ||||
| 
 | ||||
|         self.assertFalse(wait_2_deferred.called) | ||||
| 
 | ||||
|         # let the first lookup complete (in the sentinel context) | ||||
|         lookup_1_deferred.callback(None) | ||||
| 
 | ||||
|         # now the second wait should complete. | ||||
|         self.successResultOf(wait_2_deferred) | ||||
| 
 | ||||
|     def test_verify_json_objects_for_server_awaits_previous_requests(self): | ||||
|         key1 = signedjson.key.generate_signing_key(1) | ||||
| 
 | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue
	
	 Richard van der Hoff
						Richard van der Hoff