Convert the crypto module to async/await. (#8003)
parent
b6c6fb7950
commit
2a89ce8cd4
|
@ -0,0 +1 @@
|
||||||
|
Convert various parts of the codebase to async/await.
|
|
@ -223,8 +223,7 @@ class Keyring(object):
|
||||||
|
|
||||||
return results
|
return results
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def _start_key_lookups(self, verify_requests):
|
||||||
def _start_key_lookups(self, verify_requests):
|
|
||||||
"""Sets off the key fetches for each verify request
|
"""Sets off the key fetches for each verify request
|
||||||
|
|
||||||
Once each fetch completes, verify_request.key_ready will be resolved.
|
Once each fetch completes, verify_request.key_ready will be resolved.
|
||||||
|
@ -245,7 +244,7 @@ class Keyring(object):
|
||||||
server_to_request_ids.setdefault(server_name, set()).add(request_id)
|
server_to_request_ids.setdefault(server_name, set()).add(request_id)
|
||||||
|
|
||||||
# Wait for any previous lookups to complete before proceeding.
|
# Wait for any previous lookups to complete before proceeding.
|
||||||
yield self.wait_for_previous_lookups(server_to_request_ids.keys())
|
await self.wait_for_previous_lookups(server_to_request_ids.keys())
|
||||||
|
|
||||||
# take out a lock on each of the servers by sticking a Deferred in
|
# take out a lock on each of the servers by sticking a Deferred in
|
||||||
# key_downloads
|
# key_downloads
|
||||||
|
@ -283,15 +282,14 @@ class Keyring(object):
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.exception("Error starting key lookups")
|
logger.exception("Error starting key lookups")
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def wait_for_previous_lookups(self, server_names) -> None:
|
||||||
def wait_for_previous_lookups(self, server_names):
|
|
||||||
"""Waits for any previous key lookups for the given servers to finish.
|
"""Waits for any previous key lookups for the given servers to finish.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
server_names (Iterable[str]): list of servers which we want to look up
|
server_names (Iterable[str]): list of servers which we want to look up
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Deferred[None]: resolves once all key lookups for the given servers have
|
Resolves once all key lookups for the given servers have
|
||||||
completed. Follows the synapse rules of logcontext preservation.
|
completed. Follows the synapse rules of logcontext preservation.
|
||||||
"""
|
"""
|
||||||
loop_count = 1
|
loop_count = 1
|
||||||
|
@ -309,7 +307,7 @@ class Keyring(object):
|
||||||
loop_count,
|
loop_count,
|
||||||
)
|
)
|
||||||
with PreserveLoggingContext():
|
with PreserveLoggingContext():
|
||||||
yield defer.DeferredList((w[1] for w in wait_on))
|
await defer.DeferredList((w[1] for w in wait_on))
|
||||||
|
|
||||||
loop_count += 1
|
loop_count += 1
|
||||||
|
|
||||||
|
@ -326,44 +324,44 @@ class Keyring(object):
|
||||||
|
|
||||||
remaining_requests = {rq for rq in verify_requests if not rq.key_ready.called}
|
remaining_requests = {rq for rq in verify_requests if not rq.key_ready.called}
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def do_iterations():
|
||||||
def do_iterations():
|
try:
|
||||||
with Measure(self.clock, "get_server_verify_keys"):
|
with Measure(self.clock, "get_server_verify_keys"):
|
||||||
for f in self._key_fetchers:
|
for f in self._key_fetchers:
|
||||||
if not remaining_requests:
|
if not remaining_requests:
|
||||||
return
|
return
|
||||||
yield self._attempt_key_fetches_with_fetcher(f, remaining_requests)
|
await self._attempt_key_fetches_with_fetcher(
|
||||||
|
f, remaining_requests
|
||||||
# look for any requests which weren't satisfied
|
|
||||||
with PreserveLoggingContext():
|
|
||||||
for verify_request in remaining_requests:
|
|
||||||
verify_request.key_ready.errback(
|
|
||||||
SynapseError(
|
|
||||||
401,
|
|
||||||
"No key for %s with ids in %s (min_validity %i)"
|
|
||||||
% (
|
|
||||||
verify_request.server_name,
|
|
||||||
verify_request.key_ids,
|
|
||||||
verify_request.minimum_valid_until_ts,
|
|
||||||
),
|
|
||||||
Codes.UNAUTHORIZED,
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def on_err(err):
|
# look for any requests which weren't satisfied
|
||||||
# we don't really expect to get here, because any errors should already
|
with PreserveLoggingContext():
|
||||||
# have been caught and logged. But if we do, let's log the error and make
|
for verify_request in remaining_requests:
|
||||||
# sure that all of the deferreds are resolved.
|
verify_request.key_ready.errback(
|
||||||
logger.error("Unexpected error in _get_server_verify_keys: %s", err)
|
SynapseError(
|
||||||
with PreserveLoggingContext():
|
401,
|
||||||
for verify_request in remaining_requests:
|
"No key for %s with ids in %s (min_validity %i)"
|
||||||
if not verify_request.key_ready.called:
|
% (
|
||||||
verify_request.key_ready.errback(err)
|
verify_request.server_name,
|
||||||
|
verify_request.key_ids,
|
||||||
|
verify_request.minimum_valid_until_ts,
|
||||||
|
),
|
||||||
|
Codes.UNAUTHORIZED,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
except Exception as err:
|
||||||
|
# we don't really expect to get here, because any errors should already
|
||||||
|
# have been caught and logged. But if we do, let's log the error and make
|
||||||
|
# sure that all of the deferreds are resolved.
|
||||||
|
logger.error("Unexpected error in _get_server_verify_keys: %s", err)
|
||||||
|
with PreserveLoggingContext():
|
||||||
|
for verify_request in remaining_requests:
|
||||||
|
if not verify_request.key_ready.called:
|
||||||
|
verify_request.key_ready.errback(err)
|
||||||
|
|
||||||
run_in_background(do_iterations).addErrback(on_err)
|
run_in_background(do_iterations)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def _attempt_key_fetches_with_fetcher(self, fetcher, remaining_requests):
|
||||||
def _attempt_key_fetches_with_fetcher(self, fetcher, remaining_requests):
|
|
||||||
"""Use a key fetcher to attempt to satisfy some key requests
|
"""Use a key fetcher to attempt to satisfy some key requests
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@ -390,7 +388,7 @@ class Keyring(object):
|
||||||
verify_request.minimum_valid_until_ts,
|
verify_request.minimum_valid_until_ts,
|
||||||
)
|
)
|
||||||
|
|
||||||
results = yield fetcher.get_keys(missing_keys)
|
results = await fetcher.get_keys(missing_keys)
|
||||||
|
|
||||||
completed = []
|
completed = []
|
||||||
for verify_request in remaining_requests:
|
for verify_request in remaining_requests:
|
||||||
|
@ -423,7 +421,7 @@ class Keyring(object):
|
||||||
|
|
||||||
|
|
||||||
class KeyFetcher(object):
|
class KeyFetcher(object):
|
||||||
def get_keys(self, keys_to_fetch):
|
async def get_keys(self, keys_to_fetch):
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
keys_to_fetch (dict[str, dict[str, int]]):
|
keys_to_fetch (dict[str, dict[str, int]]):
|
||||||
|
@ -442,8 +440,7 @@ class StoreKeyFetcher(KeyFetcher):
|
||||||
def __init__(self, hs):
|
def __init__(self, hs):
|
||||||
self.store = hs.get_datastore()
|
self.store = hs.get_datastore()
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def get_keys(self, keys_to_fetch):
|
||||||
def get_keys(self, keys_to_fetch):
|
|
||||||
"""see KeyFetcher.get_keys"""
|
"""see KeyFetcher.get_keys"""
|
||||||
|
|
||||||
keys_to_fetch = (
|
keys_to_fetch = (
|
||||||
|
@ -452,7 +449,7 @@ class StoreKeyFetcher(KeyFetcher):
|
||||||
for key_id in keys_for_server.keys()
|
for key_id in keys_for_server.keys()
|
||||||
)
|
)
|
||||||
|
|
||||||
res = yield self.store.get_server_verify_keys(keys_to_fetch)
|
res = await self.store.get_server_verify_keys(keys_to_fetch)
|
||||||
keys = {}
|
keys = {}
|
||||||
for (server_name, key_id), key in res.items():
|
for (server_name, key_id), key in res.items():
|
||||||
keys.setdefault(server_name, {})[key_id] = key
|
keys.setdefault(server_name, {})[key_id] = key
|
||||||
|
@ -464,8 +461,7 @@ class BaseV2KeyFetcher(object):
|
||||||
self.store = hs.get_datastore()
|
self.store = hs.get_datastore()
|
||||||
self.config = hs.get_config()
|
self.config = hs.get_config()
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def process_v2_response(self, from_server, response_json, time_added_ms):
|
||||||
def process_v2_response(self, from_server, response_json, time_added_ms):
|
|
||||||
"""Parse a 'Server Keys' structure from the result of a /key request
|
"""Parse a 'Server Keys' structure from the result of a /key request
|
||||||
|
|
||||||
This is used to parse either the entirety of the response from
|
This is used to parse either the entirety of the response from
|
||||||
|
@ -537,7 +533,7 @@ class BaseV2KeyFetcher(object):
|
||||||
|
|
||||||
key_json_bytes = encode_canonical_json(response_json)
|
key_json_bytes = encode_canonical_json(response_json)
|
||||||
|
|
||||||
yield make_deferred_yieldable(
|
await make_deferred_yieldable(
|
||||||
defer.gatherResults(
|
defer.gatherResults(
|
||||||
[
|
[
|
||||||
run_in_background(
|
run_in_background(
|
||||||
|
@ -567,14 +563,12 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
|
||||||
self.client = hs.get_http_client()
|
self.client = hs.get_http_client()
|
||||||
self.key_servers = self.config.key_servers
|
self.key_servers = self.config.key_servers
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def get_keys(self, keys_to_fetch):
|
||||||
def get_keys(self, keys_to_fetch):
|
|
||||||
"""see KeyFetcher.get_keys"""
|
"""see KeyFetcher.get_keys"""
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def get_key(key_server):
|
||||||
def get_key(key_server):
|
|
||||||
try:
|
try:
|
||||||
result = yield self.get_server_verify_key_v2_indirect(
|
result = await self.get_server_verify_key_v2_indirect(
|
||||||
keys_to_fetch, key_server
|
keys_to_fetch, key_server
|
||||||
)
|
)
|
||||||
return result
|
return result
|
||||||
|
@ -592,7 +586,7 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
|
||||||
|
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
results = yield make_deferred_yieldable(
|
results = await make_deferred_yieldable(
|
||||||
defer.gatherResults(
|
defer.gatherResults(
|
||||||
[run_in_background(get_key, server) for server in self.key_servers],
|
[run_in_background(get_key, server) for server in self.key_servers],
|
||||||
consumeErrors=True,
|
consumeErrors=True,
|
||||||
|
@ -606,8 +600,7 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
|
||||||
|
|
||||||
return union_of_keys
|
return union_of_keys
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def get_server_verify_key_v2_indirect(self, keys_to_fetch, key_server):
|
||||||
def get_server_verify_key_v2_indirect(self, keys_to_fetch, key_server):
|
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
keys_to_fetch (dict[str, dict[str, int]]):
|
keys_to_fetch (dict[str, dict[str, int]]):
|
||||||
|
@ -617,7 +610,7 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
|
||||||
the keys
|
the keys
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Deferred[dict[str, dict[str, synapse.storage.keys.FetchKeyResult]]]: map
|
dict[str, dict[str, synapse.storage.keys.FetchKeyResult]]: map
|
||||||
from server_name -> key_id -> FetchKeyResult
|
from server_name -> key_id -> FetchKeyResult
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
|
@ -632,20 +625,18 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
query_response = yield defer.ensureDeferred(
|
query_response = await self.client.post_json(
|
||||||
self.client.post_json(
|
destination=perspective_name,
|
||||||
destination=perspective_name,
|
path="/_matrix/key/v2/query",
|
||||||
path="/_matrix/key/v2/query",
|
data={
|
||||||
data={
|
"server_keys": {
|
||||||
"server_keys": {
|
server_name: {
|
||||||
server_name: {
|
key_id: {"minimum_valid_until_ts": min_valid_ts}
|
||||||
key_id: {"minimum_valid_until_ts": min_valid_ts}
|
for key_id, min_valid_ts in server_keys.items()
|
||||||
for key_id, min_valid_ts in server_keys.items()
|
|
||||||
}
|
|
||||||
for server_name, server_keys in keys_to_fetch.items()
|
|
||||||
}
|
}
|
||||||
},
|
for server_name, server_keys in keys_to_fetch.items()
|
||||||
)
|
}
|
||||||
|
},
|
||||||
)
|
)
|
||||||
except (NotRetryingDestination, RequestSendFailed) as e:
|
except (NotRetryingDestination, RequestSendFailed) as e:
|
||||||
# these both have str() representations which we can't really improve upon
|
# these both have str() representations which we can't really improve upon
|
||||||
|
@ -670,7 +661,7 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
|
||||||
try:
|
try:
|
||||||
self._validate_perspectives_response(key_server, response)
|
self._validate_perspectives_response(key_server, response)
|
||||||
|
|
||||||
processed_response = yield self.process_v2_response(
|
processed_response = await self.process_v2_response(
|
||||||
perspective_name, response, time_added_ms=time_now_ms
|
perspective_name, response, time_added_ms=time_now_ms
|
||||||
)
|
)
|
||||||
except KeyLookupError as e:
|
except KeyLookupError as e:
|
||||||
|
@ -689,7 +680,7 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
|
||||||
)
|
)
|
||||||
keys.setdefault(server_name, {}).update(processed_response)
|
keys.setdefault(server_name, {}).update(processed_response)
|
||||||
|
|
||||||
yield self.store.store_server_verify_keys(
|
await self.store.store_server_verify_keys(
|
||||||
perspective_name, time_now_ms, added_keys
|
perspective_name, time_now_ms, added_keys
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -741,24 +732,23 @@ class ServerKeyFetcher(BaseV2KeyFetcher):
|
||||||
self.clock = hs.get_clock()
|
self.clock = hs.get_clock()
|
||||||
self.client = hs.get_http_client()
|
self.client = hs.get_http_client()
|
||||||
|
|
||||||
def get_keys(self, keys_to_fetch):
|
async def get_keys(self, keys_to_fetch):
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
keys_to_fetch (dict[str, iterable[str]]):
|
keys_to_fetch (dict[str, iterable[str]]):
|
||||||
the keys to be fetched. server_name -> key_ids
|
the keys to be fetched. server_name -> key_ids
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Deferred[dict[str, dict[str, synapse.storage.keys.FetchKeyResult|None]]]:
|
dict[str, dict[str, synapse.storage.keys.FetchKeyResult|None]]:
|
||||||
map from server_name -> key_id -> FetchKeyResult
|
map from server_name -> key_id -> FetchKeyResult
|
||||||
"""
|
"""
|
||||||
|
|
||||||
results = {}
|
results = {}
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def get_key(key_to_fetch_item):
|
||||||
def get_key(key_to_fetch_item):
|
|
||||||
server_name, key_ids = key_to_fetch_item
|
server_name, key_ids = key_to_fetch_item
|
||||||
try:
|
try:
|
||||||
keys = yield self.get_server_verify_key_v2_direct(server_name, key_ids)
|
keys = await self.get_server_verify_key_v2_direct(server_name, key_ids)
|
||||||
results[server_name] = keys
|
results[server_name] = keys
|
||||||
except KeyLookupError as e:
|
except KeyLookupError as e:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
|
@ -767,12 +757,11 @@ class ServerKeyFetcher(BaseV2KeyFetcher):
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.exception("Error getting keys %s from %s", key_ids, server_name)
|
logger.exception("Error getting keys %s from %s", key_ids, server_name)
|
||||||
|
|
||||||
return yieldable_gather_results(get_key, keys_to_fetch.items()).addCallback(
|
return await yieldable_gather_results(
|
||||||
lambda _: results
|
get_key, keys_to_fetch.items()
|
||||||
)
|
).addCallback(lambda _: results)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def get_server_verify_key_v2_direct(self, server_name, key_ids):
|
||||||
def get_server_verify_key_v2_direct(self, server_name, key_ids):
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@ -794,25 +783,23 @@ class ServerKeyFetcher(BaseV2KeyFetcher):
|
||||||
|
|
||||||
time_now_ms = self.clock.time_msec()
|
time_now_ms = self.clock.time_msec()
|
||||||
try:
|
try:
|
||||||
response = yield defer.ensureDeferred(
|
response = await self.client.get_json(
|
||||||
self.client.get_json(
|
destination=server_name,
|
||||||
destination=server_name,
|
path="/_matrix/key/v2/server/"
|
||||||
path="/_matrix/key/v2/server/"
|
+ urllib.parse.quote(requested_key_id),
|
||||||
+ urllib.parse.quote(requested_key_id),
|
ignore_backoff=True,
|
||||||
ignore_backoff=True,
|
# we only give the remote server 10s to respond. It should be an
|
||||||
# we only give the remote server 10s to respond. It should be an
|
# easy request to handle, so if it doesn't reply within 10s, it's
|
||||||
# easy request to handle, so if it doesn't reply within 10s, it's
|
# probably not going to.
|
||||||
# probably not going to.
|
#
|
||||||
#
|
# Furthermore, when we are acting as a notary server, we cannot
|
||||||
# Furthermore, when we are acting as a notary server, we cannot
|
# wait all day for all of the origin servers, as the requesting
|
||||||
# wait all day for all of the origin servers, as the requesting
|
# server will otherwise time out before we can respond.
|
||||||
# server will otherwise time out before we can respond.
|
#
|
||||||
#
|
# (Note that get_json may make 4 attempts, so this can still take
|
||||||
# (Note that get_json may make 4 attempts, so this can still take
|
# almost 45 seconds to fetch the headers, plus up to another 60s to
|
||||||
# almost 45 seconds to fetch the headers, plus up to another 60s to
|
# read the response).
|
||||||
# read the response).
|
timeout=10000,
|
||||||
timeout=10000,
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
except (NotRetryingDestination, RequestSendFailed) as e:
|
except (NotRetryingDestination, RequestSendFailed) as e:
|
||||||
# these both have str() representations which we can't really improve
|
# these both have str() representations which we can't really improve
|
||||||
|
@ -827,12 +814,12 @@ class ServerKeyFetcher(BaseV2KeyFetcher):
|
||||||
% (server_name, response["server_name"])
|
% (server_name, response["server_name"])
|
||||||
)
|
)
|
||||||
|
|
||||||
response_keys = yield self.process_v2_response(
|
response_keys = await self.process_v2_response(
|
||||||
from_server=server_name,
|
from_server=server_name,
|
||||||
response_json=response,
|
response_json=response,
|
||||||
time_added_ms=time_now_ms,
|
time_added_ms=time_now_ms,
|
||||||
)
|
)
|
||||||
yield self.store.store_server_verify_keys(
|
await self.store.store_server_verify_keys(
|
||||||
server_name,
|
server_name,
|
||||||
time_now_ms,
|
time_now_ms,
|
||||||
((server_name, key_id, key) for key_id, key in response_keys.items()),
|
((server_name, key_id, key) for key_id, key in response_keys.items()),
|
||||||
|
@ -842,22 +829,18 @@ class ServerKeyFetcher(BaseV2KeyFetcher):
|
||||||
return keys
|
return keys
|
||||||
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def _handle_key_deferred(verify_request) -> None:
|
||||||
def _handle_key_deferred(verify_request):
|
|
||||||
"""Waits for the key to become available, and then performs a verification
|
"""Waits for the key to become available, and then performs a verification
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
verify_request (VerifyJsonRequest):
|
verify_request (VerifyJsonRequest):
|
||||||
|
|
||||||
Returns:
|
|
||||||
Deferred[None]
|
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
SynapseError if there was a problem performing the verification
|
SynapseError if there was a problem performing the verification
|
||||||
"""
|
"""
|
||||||
server_name = verify_request.server_name
|
server_name = verify_request.server_name
|
||||||
with PreserveLoggingContext():
|
with PreserveLoggingContext():
|
||||||
_, key_id, verify_key = yield verify_request.key_ready
|
_, key_id, verify_key = await verify_request.key_ready
|
||||||
|
|
||||||
json_object = verify_request.json_object
|
json_object = verify_request.json_object
|
||||||
|
|
||||||
|
|
|
@ -40,6 +40,7 @@ from synapse.logging.context import (
|
||||||
from synapse.storage.keys import FetchKeyResult
|
from synapse.storage.keys import FetchKeyResult
|
||||||
|
|
||||||
from tests import unittest
|
from tests import unittest
|
||||||
|
from tests.test_utils import make_awaitable
|
||||||
|
|
||||||
|
|
||||||
class MockPerspectiveServer(object):
|
class MockPerspectiveServer(object):
|
||||||
|
@ -201,7 +202,7 @@ class KeyringTestCase(unittest.HomeserverTestCase):
|
||||||
with a null `ts_valid_until_ms`
|
with a null `ts_valid_until_ms`
|
||||||
"""
|
"""
|
||||||
mock_fetcher = keyring.KeyFetcher()
|
mock_fetcher = keyring.KeyFetcher()
|
||||||
mock_fetcher.get_keys = Mock(return_value=defer.succeed({}))
|
mock_fetcher.get_keys = Mock(return_value=make_awaitable({}))
|
||||||
|
|
||||||
kr = keyring.Keyring(
|
kr = keyring.Keyring(
|
||||||
self.hs, key_fetchers=(StoreKeyFetcher(self.hs), mock_fetcher)
|
self.hs, key_fetchers=(StoreKeyFetcher(self.hs), mock_fetcher)
|
||||||
|
@ -244,17 +245,15 @@ class KeyringTestCase(unittest.HomeserverTestCase):
|
||||||
"""Two requests for the same key should be deduped."""
|
"""Two requests for the same key should be deduped."""
|
||||||
key1 = signedjson.key.generate_signing_key(1)
|
key1 = signedjson.key.generate_signing_key(1)
|
||||||
|
|
||||||
def get_keys(keys_to_fetch):
|
async def get_keys(keys_to_fetch):
|
||||||
# there should only be one request object (with the max validity)
|
# there should only be one request object (with the max validity)
|
||||||
self.assertEqual(keys_to_fetch, {"server1": {get_key_id(key1): 1500}})
|
self.assertEqual(keys_to_fetch, {"server1": {get_key_id(key1): 1500}})
|
||||||
|
|
||||||
return defer.succeed(
|
return {
|
||||||
{
|
"server1": {
|
||||||
"server1": {
|
get_key_id(key1): FetchKeyResult(get_verify_key(key1), 1200)
|
||||||
get_key_id(key1): FetchKeyResult(get_verify_key(key1), 1200)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
)
|
}
|
||||||
|
|
||||||
mock_fetcher = keyring.KeyFetcher()
|
mock_fetcher = keyring.KeyFetcher()
|
||||||
mock_fetcher.get_keys = Mock(side_effect=get_keys)
|
mock_fetcher.get_keys = Mock(side_effect=get_keys)
|
||||||
|
@ -281,25 +280,19 @@ class KeyringTestCase(unittest.HomeserverTestCase):
|
||||||
"""If the first fetcher cannot provide a recent enough key, we fall back"""
|
"""If the first fetcher cannot provide a recent enough key, we fall back"""
|
||||||
key1 = signedjson.key.generate_signing_key(1)
|
key1 = signedjson.key.generate_signing_key(1)
|
||||||
|
|
||||||
def get_keys1(keys_to_fetch):
|
async def get_keys1(keys_to_fetch):
|
||||||
self.assertEqual(keys_to_fetch, {"server1": {get_key_id(key1): 1500}})
|
self.assertEqual(keys_to_fetch, {"server1": {get_key_id(key1): 1500}})
|
||||||
return defer.succeed(
|
return {
|
||||||
{
|
"server1": {get_key_id(key1): FetchKeyResult(get_verify_key(key1), 800)}
|
||||||
"server1": {
|
}
|
||||||
get_key_id(key1): FetchKeyResult(get_verify_key(key1), 800)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
def get_keys2(keys_to_fetch):
|
async def get_keys2(keys_to_fetch):
|
||||||
self.assertEqual(keys_to_fetch, {"server1": {get_key_id(key1): 1500}})
|
self.assertEqual(keys_to_fetch, {"server1": {get_key_id(key1): 1500}})
|
||||||
return defer.succeed(
|
return {
|
||||||
{
|
"server1": {
|
||||||
"server1": {
|
get_key_id(key1): FetchKeyResult(get_verify_key(key1), 1200)
|
||||||
get_key_id(key1): FetchKeyResult(get_verify_key(key1), 1200)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
)
|
}
|
||||||
|
|
||||||
mock_fetcher1 = keyring.KeyFetcher()
|
mock_fetcher1 = keyring.KeyFetcher()
|
||||||
mock_fetcher1.get_keys = Mock(side_effect=get_keys1)
|
mock_fetcher1.get_keys = Mock(side_effect=get_keys1)
|
||||||
|
|
Loading…
Reference in New Issue