Convert the crypto module to async/await. (#8003)

pull/8027/head
Patrick Cloke 2020-08-03 08:29:01 -04:00 committed by GitHub
parent b6c6fb7950
commit 2a89ce8cd4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 110 additions and 133 deletions

1
changelog.d/8003.misc Normal file
View File

@ -0,0 +1 @@
Convert various parts of the codebase to async/await.

View File

@ -223,8 +223,7 @@ class Keyring(object):
return results
@defer.inlineCallbacks
def _start_key_lookups(self, verify_requests):
async def _start_key_lookups(self, verify_requests):
"""Sets off the key fetches for each verify request
Once each fetch completes, verify_request.key_ready will be resolved.
@ -245,7 +244,7 @@ class Keyring(object):
server_to_request_ids.setdefault(server_name, set()).add(request_id)
# Wait for any previous lookups to complete before proceeding.
yield self.wait_for_previous_lookups(server_to_request_ids.keys())
await self.wait_for_previous_lookups(server_to_request_ids.keys())
# take out a lock on each of the servers by sticking a Deferred in
# key_downloads
@ -283,15 +282,14 @@ class Keyring(object):
except Exception:
logger.exception("Error starting key lookups")
@defer.inlineCallbacks
def wait_for_previous_lookups(self, server_names):
async def wait_for_previous_lookups(self, server_names) -> None:
"""Waits for any previous key lookups for the given servers to finish.
Args:
server_names (Iterable[str]): list of servers which we want to look up
Returns:
Deferred[None]: resolves once all key lookups for the given servers have
Resolves once all key lookups for the given servers have
completed. Follows the synapse rules of logcontext preservation.
"""
loop_count = 1
@ -309,7 +307,7 @@ class Keyring(object):
loop_count,
)
with PreserveLoggingContext():
yield defer.DeferredList((w[1] for w in wait_on))
await defer.DeferredList((w[1] for w in wait_on))
loop_count += 1
@ -326,13 +324,15 @@ class Keyring(object):
remaining_requests = {rq for rq in verify_requests if not rq.key_ready.called}
@defer.inlineCallbacks
def do_iterations():
async def do_iterations():
try:
with Measure(self.clock, "get_server_verify_keys"):
for f in self._key_fetchers:
if not remaining_requests:
return
yield self._attempt_key_fetches_with_fetcher(f, remaining_requests)
await self._attempt_key_fetches_with_fetcher(
f, remaining_requests
)
# look for any requests which weren't satisfied
with PreserveLoggingContext():
@ -349,8 +349,7 @@ class Keyring(object):
Codes.UNAUTHORIZED,
)
)
def on_err(err):
except Exception as err:
# we don't really expect to get here, because any errors should already
# have been caught and logged. But if we do, let's log the error and make
# sure that all of the deferreds are resolved.
@ -360,10 +359,9 @@ class Keyring(object):
if not verify_request.key_ready.called:
verify_request.key_ready.errback(err)
run_in_background(do_iterations).addErrback(on_err)
run_in_background(do_iterations)
@defer.inlineCallbacks
def _attempt_key_fetches_with_fetcher(self, fetcher, remaining_requests):
async def _attempt_key_fetches_with_fetcher(self, fetcher, remaining_requests):
"""Use a key fetcher to attempt to satisfy some key requests
Args:
@ -390,7 +388,7 @@ class Keyring(object):
verify_request.minimum_valid_until_ts,
)
results = yield fetcher.get_keys(missing_keys)
results = await fetcher.get_keys(missing_keys)
completed = []
for verify_request in remaining_requests:
@ -423,7 +421,7 @@ class Keyring(object):
class KeyFetcher(object):
def get_keys(self, keys_to_fetch):
async def get_keys(self, keys_to_fetch):
"""
Args:
keys_to_fetch (dict[str, dict[str, int]]):
@ -442,8 +440,7 @@ class StoreKeyFetcher(KeyFetcher):
def __init__(self, hs):
self.store = hs.get_datastore()
@defer.inlineCallbacks
def get_keys(self, keys_to_fetch):
async def get_keys(self, keys_to_fetch):
"""see KeyFetcher.get_keys"""
keys_to_fetch = (
@ -452,7 +449,7 @@ class StoreKeyFetcher(KeyFetcher):
for key_id in keys_for_server.keys()
)
res = yield self.store.get_server_verify_keys(keys_to_fetch)
res = await self.store.get_server_verify_keys(keys_to_fetch)
keys = {}
for (server_name, key_id), key in res.items():
keys.setdefault(server_name, {})[key_id] = key
@ -464,8 +461,7 @@ class BaseV2KeyFetcher(object):
self.store = hs.get_datastore()
self.config = hs.get_config()
@defer.inlineCallbacks
def process_v2_response(self, from_server, response_json, time_added_ms):
async def process_v2_response(self, from_server, response_json, time_added_ms):
"""Parse a 'Server Keys' structure from the result of a /key request
This is used to parse either the entirety of the response from
@ -537,7 +533,7 @@ class BaseV2KeyFetcher(object):
key_json_bytes = encode_canonical_json(response_json)
yield make_deferred_yieldable(
await make_deferred_yieldable(
defer.gatherResults(
[
run_in_background(
@ -567,14 +563,12 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
self.client = hs.get_http_client()
self.key_servers = self.config.key_servers
@defer.inlineCallbacks
def get_keys(self, keys_to_fetch):
async def get_keys(self, keys_to_fetch):
"""see KeyFetcher.get_keys"""
@defer.inlineCallbacks
def get_key(key_server):
async def get_key(key_server):
try:
result = yield self.get_server_verify_key_v2_indirect(
result = await self.get_server_verify_key_v2_indirect(
keys_to_fetch, key_server
)
return result
@ -592,7 +586,7 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
return {}
results = yield make_deferred_yieldable(
results = await make_deferred_yieldable(
defer.gatherResults(
[run_in_background(get_key, server) for server in self.key_servers],
consumeErrors=True,
@ -606,8 +600,7 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
return union_of_keys
@defer.inlineCallbacks
def get_server_verify_key_v2_indirect(self, keys_to_fetch, key_server):
async def get_server_verify_key_v2_indirect(self, keys_to_fetch, key_server):
"""
Args:
keys_to_fetch (dict[str, dict[str, int]]):
@ -617,7 +610,7 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
the keys
Returns:
Deferred[dict[str, dict[str, synapse.storage.keys.FetchKeyResult]]]: map
dict[str, dict[str, synapse.storage.keys.FetchKeyResult]]: map
from server_name -> key_id -> FetchKeyResult
Raises:
@ -632,8 +625,7 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
)
try:
query_response = yield defer.ensureDeferred(
self.client.post_json(
query_response = await self.client.post_json(
destination=perspective_name,
path="/_matrix/key/v2/query",
data={
@ -646,7 +638,6 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
}
},
)
)
except (NotRetryingDestination, RequestSendFailed) as e:
# these both have str() representations which we can't really improve upon
raise KeyLookupError(str(e))
@ -670,7 +661,7 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
try:
self._validate_perspectives_response(key_server, response)
processed_response = yield self.process_v2_response(
processed_response = await self.process_v2_response(
perspective_name, response, time_added_ms=time_now_ms
)
except KeyLookupError as e:
@ -689,7 +680,7 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
)
keys.setdefault(server_name, {}).update(processed_response)
yield self.store.store_server_verify_keys(
await self.store.store_server_verify_keys(
perspective_name, time_now_ms, added_keys
)
@ -741,24 +732,23 @@ class ServerKeyFetcher(BaseV2KeyFetcher):
self.clock = hs.get_clock()
self.client = hs.get_http_client()
def get_keys(self, keys_to_fetch):
async def get_keys(self, keys_to_fetch):
"""
Args:
keys_to_fetch (dict[str, iterable[str]]):
the keys to be fetched. server_name -> key_ids
Returns:
Deferred[dict[str, dict[str, synapse.storage.keys.FetchKeyResult|None]]]:
dict[str, dict[str, synapse.storage.keys.FetchKeyResult|None]]:
map from server_name -> key_id -> FetchKeyResult
"""
results = {}
@defer.inlineCallbacks
def get_key(key_to_fetch_item):
async def get_key(key_to_fetch_item):
server_name, key_ids = key_to_fetch_item
try:
keys = yield self.get_server_verify_key_v2_direct(server_name, key_ids)
keys = await self.get_server_verify_key_v2_direct(server_name, key_ids)
results[server_name] = keys
except KeyLookupError as e:
logger.warning(
@ -767,12 +757,11 @@ class ServerKeyFetcher(BaseV2KeyFetcher):
except Exception:
logger.exception("Error getting keys %s from %s", key_ids, server_name)
return yieldable_gather_results(get_key, keys_to_fetch.items()).addCallback(
lambda _: results
)
return await yieldable_gather_results(
get_key, keys_to_fetch.items()
).addCallback(lambda _: results)
@defer.inlineCallbacks
def get_server_verify_key_v2_direct(self, server_name, key_ids):
async def get_server_verify_key_v2_direct(self, server_name, key_ids):
"""
Args:
@ -794,8 +783,7 @@ class ServerKeyFetcher(BaseV2KeyFetcher):
time_now_ms = self.clock.time_msec()
try:
response = yield defer.ensureDeferred(
self.client.get_json(
response = await self.client.get_json(
destination=server_name,
path="/_matrix/key/v2/server/"
+ urllib.parse.quote(requested_key_id),
@ -813,7 +801,6 @@ class ServerKeyFetcher(BaseV2KeyFetcher):
# read the response).
timeout=10000,
)
)
except (NotRetryingDestination, RequestSendFailed) as e:
# these both have str() representations which we can't really improve
# upon
@ -827,12 +814,12 @@ class ServerKeyFetcher(BaseV2KeyFetcher):
% (server_name, response["server_name"])
)
response_keys = yield self.process_v2_response(
response_keys = await self.process_v2_response(
from_server=server_name,
response_json=response,
time_added_ms=time_now_ms,
)
yield self.store.store_server_verify_keys(
await self.store.store_server_verify_keys(
server_name,
time_now_ms,
((server_name, key_id, key) for key_id, key in response_keys.items()),
@ -842,22 +829,18 @@ class ServerKeyFetcher(BaseV2KeyFetcher):
return keys
@defer.inlineCallbacks
def _handle_key_deferred(verify_request):
async def _handle_key_deferred(verify_request) -> None:
"""Waits for the key to become available, and then performs a verification
Args:
verify_request (VerifyJsonRequest):
Returns:
Deferred[None]
Raises:
SynapseError if there was a problem performing the verification
"""
server_name = verify_request.server_name
with PreserveLoggingContext():
_, key_id, verify_key = yield verify_request.key_ready
_, key_id, verify_key = await verify_request.key_ready
json_object = verify_request.json_object

View File

@ -40,6 +40,7 @@ from synapse.logging.context import (
from synapse.storage.keys import FetchKeyResult
from tests import unittest
from tests.test_utils import make_awaitable
class MockPerspectiveServer(object):
@ -201,7 +202,7 @@ class KeyringTestCase(unittest.HomeserverTestCase):
with a null `ts_valid_until_ms`
"""
mock_fetcher = keyring.KeyFetcher()
mock_fetcher.get_keys = Mock(return_value=defer.succeed({}))
mock_fetcher.get_keys = Mock(return_value=make_awaitable({}))
kr = keyring.Keyring(
self.hs, key_fetchers=(StoreKeyFetcher(self.hs), mock_fetcher)
@ -244,17 +245,15 @@ class KeyringTestCase(unittest.HomeserverTestCase):
"""Two requests for the same key should be deduped."""
key1 = signedjson.key.generate_signing_key(1)
def get_keys(keys_to_fetch):
async def get_keys(keys_to_fetch):
# there should only be one request object (with the max validity)
self.assertEqual(keys_to_fetch, {"server1": {get_key_id(key1): 1500}})
return defer.succeed(
{
return {
"server1": {
get_key_id(key1): FetchKeyResult(get_verify_key(key1), 1200)
}
}
)
mock_fetcher = keyring.KeyFetcher()
mock_fetcher.get_keys = Mock(side_effect=get_keys)
@ -281,25 +280,19 @@ class KeyringTestCase(unittest.HomeserverTestCase):
"""If the first fetcher cannot provide a recent enough key, we fall back"""
key1 = signedjson.key.generate_signing_key(1)
def get_keys1(keys_to_fetch):
async def get_keys1(keys_to_fetch):
self.assertEqual(keys_to_fetch, {"server1": {get_key_id(key1): 1500}})
return defer.succeed(
{
"server1": {
get_key_id(key1): FetchKeyResult(get_verify_key(key1), 800)
return {
"server1": {get_key_id(key1): FetchKeyResult(get_verify_key(key1), 800)}
}
}
)
def get_keys2(keys_to_fetch):
async def get_keys2(keys_to_fetch):
self.assertEqual(keys_to_fetch, {"server1": {get_key_id(key1): 1500}})
return defer.succeed(
{
return {
"server1": {
get_key_id(key1): FetchKeyResult(get_verify_key(key1), 1200)
}
}
)
mock_fetcher1 = keyring.KeyFetcher()
mock_fetcher1.get_keys = Mock(side_effect=get_keys1)