Factor out KeyFetchers from KeyRing
Rather than have three methods which have to have the same interface, factor out a separate interface which is provided by three implementations. I find it easier to grok the code this way.pull/5244/head
							parent
							
								
									b75537beaf
								
							
						
					
					
						commit
						895b79ac2e
					
				|  | @ -0,0 +1 @@ | |||
| Refactor synapse.crypto.keyring to use a KeyFetcher interface. | ||||
|  | @ -80,12 +80,13 @@ class KeyLookupError(ValueError): | |||
| 
 | ||||
| class Keyring(object): | ||||
|     def __init__(self, hs): | ||||
|         self.store = hs.get_datastore() | ||||
|         self.clock = hs.get_clock() | ||||
|         self.client = hs.get_http_client() | ||||
|         self.config = hs.get_config() | ||||
|         self.perspective_servers = self.config.perspectives | ||||
|         self.hs = hs | ||||
| 
 | ||||
|         self._key_fetchers = ( | ||||
|             StoreKeyFetcher(hs), | ||||
|             PerspectivesKeyFetcher(hs), | ||||
|             ServerKeyFetcher(hs), | ||||
|         ) | ||||
| 
 | ||||
|         # map from server name to Deferred. Has an entry for each server with | ||||
|         # an ongoing key download; the Deferred completes once the download | ||||
|  | @ -271,13 +272,6 @@ class Keyring(object): | |||
|             verify_requests (list[VerifyKeyRequest]): list of verify requests | ||||
|         """ | ||||
| 
 | ||||
|         # These are functions that produce keys given a list of key ids | ||||
|         key_fetch_fns = ( | ||||
|             self.get_keys_from_store,  # First try the local store | ||||
|             self.get_keys_from_perspectives,  # Then try via perspectives | ||||
|             self.get_keys_from_server,  # Then try directly | ||||
|         ) | ||||
| 
 | ||||
|         @defer.inlineCallbacks | ||||
|         def do_iterations(): | ||||
|             with Measure(self.clock, "get_server_verify_keys"): | ||||
|  | @ -288,8 +282,8 @@ class Keyring(object): | |||
|                         verify_request.key_ids | ||||
|                     ) | ||||
| 
 | ||||
|                 for fn in key_fetch_fns: | ||||
|                     results = yield fn(missing_keys.items()) | ||||
|                 for f in self._key_fetchers: | ||||
|                     results = yield f.get_keys(missing_keys.items()) | ||||
| 
 | ||||
|                     # We now need to figure out which verify requests we have keys | ||||
|                     # for and which we don't | ||||
|  | @ -348,8 +342,9 @@ class Keyring(object): | |||
| 
 | ||||
|         run_in_background(do_iterations).addErrback(on_err) | ||||
| 
 | ||||
|     @defer.inlineCallbacks | ||||
|     def get_keys_from_store(self, server_name_and_key_ids): | ||||
| 
 | ||||
| class KeyFetcher(object): | ||||
|     def get_keys(self, server_name_and_key_ids): | ||||
|         """ | ||||
|         Args: | ||||
|             server_name_and_key_ids (iterable[Tuple[str, iterable[str]]]): | ||||
|  | @ -359,6 +354,18 @@ class Keyring(object): | |||
|             Deferred[dict[str, dict[str, synapse.storage.keys.FetchKeyResult|None]]]: | ||||
|                 map from server_name -> key_id -> FetchKeyResult | ||||
|         """ | ||||
|         raise NotImplementedError | ||||
| 
 | ||||
| 
 | ||||
| class StoreKeyFetcher(KeyFetcher): | ||||
|     """KeyFetcher impl which fetches keys from our data store""" | ||||
| 
 | ||||
|     def __init__(self, hs): | ||||
|         self.store = hs.get_datastore() | ||||
| 
 | ||||
|     @defer.inlineCallbacks | ||||
|     def get_keys(self, server_name_and_key_ids): | ||||
|         """see KeyFetcher.get_keys""" | ||||
|         keys_to_fetch = ( | ||||
|             (server_name, key_id) | ||||
|             for server_name, key_ids in server_name_and_key_ids | ||||
|  | @ -370,203 +377,11 @@ class Keyring(object): | |||
|             keys.setdefault(server_name, {})[key_id] = key | ||||
|         defer.returnValue(keys) | ||||
| 
 | ||||
|     @defer.inlineCallbacks | ||||
|     def get_keys_from_perspectives(self, server_name_and_key_ids): | ||||
|         @defer.inlineCallbacks | ||||
|         def get_key(perspective_name, perspective_keys): | ||||
|             try: | ||||
|                 result = yield self.get_server_verify_key_v2_indirect( | ||||
|                     server_name_and_key_ids, perspective_name, perspective_keys | ||||
|                 ) | ||||
|                 defer.returnValue(result) | ||||
|             except KeyLookupError as e: | ||||
|                 logger.warning("Key lookup failed from %r: %s", perspective_name, e) | ||||
|             except Exception as e: | ||||
|                 logger.exception( | ||||
|                     "Unable to get key from %r: %s %s", | ||||
|                     perspective_name, | ||||
|                     type(e).__name__, | ||||
|                     str(e), | ||||
|                 ) | ||||
| 
 | ||||
|             defer.returnValue({}) | ||||
| 
 | ||||
|         results = yield logcontext.make_deferred_yieldable( | ||||
|             defer.gatherResults( | ||||
|                 [ | ||||
|                     run_in_background(get_key, p_name, p_keys) | ||||
|                     for p_name, p_keys in self.perspective_servers.items() | ||||
|                 ], | ||||
|                 consumeErrors=True, | ||||
|             ).addErrback(unwrapFirstError) | ||||
|         ) | ||||
| 
 | ||||
|         union_of_keys = {} | ||||
|         for result in results: | ||||
|             for server_name, keys in result.items(): | ||||
|                 union_of_keys.setdefault(server_name, {}).update(keys) | ||||
| 
 | ||||
|         defer.returnValue(union_of_keys) | ||||
| 
 | ||||
|     @defer.inlineCallbacks | ||||
|     def get_keys_from_server(self, server_name_and_key_ids): | ||||
|         results = yield logcontext.make_deferred_yieldable( | ||||
|             defer.gatherResults( | ||||
|                 [ | ||||
|                     run_in_background( | ||||
|                         self.get_server_verify_key_v2_direct, server_name, key_ids | ||||
|                     ) | ||||
|                     for server_name, key_ids in server_name_and_key_ids | ||||
|                 ], | ||||
|                 consumeErrors=True, | ||||
|             ).addErrback(unwrapFirstError) | ||||
|         ) | ||||
| 
 | ||||
|         merged = {} | ||||
|         for result in results: | ||||
|             merged.update(result) | ||||
| 
 | ||||
|         defer.returnValue( | ||||
|             {server_name: keys for server_name, keys in merged.items() if keys} | ||||
|         ) | ||||
| 
 | ||||
|     @defer.inlineCallbacks | ||||
|     def get_server_verify_key_v2_indirect( | ||||
|         self, server_names_and_key_ids, perspective_name, perspective_keys | ||||
|     ): | ||||
|         """ | ||||
|         Args: | ||||
|             server_names_and_key_ids (iterable[Tuple[str, iterable[str]]]): | ||||
|                 list of (server_name, iterable[key_id]) tuples to fetch keys for | ||||
|             perspective_name (str): name of the notary server to query for the keys | ||||
|             perspective_keys (dict[str, VerifyKey]): map of key_id->key for the | ||||
|                 notary server | ||||
| 
 | ||||
|         Returns: | ||||
|             Deferred[dict[str, dict[str, synapse.storage.keys.FetchKeyResult]]]: map | ||||
|                 from server_name -> key_id -> FetchKeyResult | ||||
|         """ | ||||
|         # TODO(mark): Set the minimum_valid_until_ts to that needed by | ||||
|         # the events being validated or the current time if validating | ||||
|         # an incoming request. | ||||
|         try: | ||||
|             query_response = yield self.client.post_json( | ||||
|                 destination=perspective_name, | ||||
|                 path="/_matrix/key/v2/query", | ||||
|                 data={ | ||||
|                     u"server_keys": { | ||||
|                         server_name: { | ||||
|                             key_id: {u"minimum_valid_until_ts": 0} for key_id in key_ids | ||||
|                         } | ||||
|                         for server_name, key_ids in server_names_and_key_ids | ||||
|                     } | ||||
|                 }, | ||||
|                 long_retries=True, | ||||
|             ) | ||||
|         except (NotRetryingDestination, RequestSendFailed) as e: | ||||
|             raise_from(KeyLookupError("Failed to connect to remote server"), e) | ||||
|         except HttpResponseException as e: | ||||
|             raise_from(KeyLookupError("Remote server returned an error"), e) | ||||
| 
 | ||||
|         keys = {} | ||||
|         added_keys = [] | ||||
| 
 | ||||
|         time_now_ms = self.clock.time_msec() | ||||
| 
 | ||||
|         for response in query_response["server_keys"]: | ||||
|             if ( | ||||
|                 u"signatures" not in response | ||||
|                 or perspective_name not in response[u"signatures"] | ||||
|             ): | ||||
|                 raise KeyLookupError( | ||||
|                     "Key response not signed by perspective server" | ||||
|                     " %r" % (perspective_name,) | ||||
|                 ) | ||||
| 
 | ||||
|             verified = False | ||||
|             for key_id in response[u"signatures"][perspective_name]: | ||||
|                 if key_id in perspective_keys: | ||||
|                     verify_signed_json( | ||||
|                         response, perspective_name, perspective_keys[key_id] | ||||
|                     ) | ||||
|                     verified = True | ||||
| 
 | ||||
|             if not verified: | ||||
|                 logging.info( | ||||
|                     "Response from perspective server %r not signed with a" | ||||
|                     " known key, signed with: %r, known keys: %r", | ||||
|                     perspective_name, | ||||
|                     list(response[u"signatures"][perspective_name]), | ||||
|                     list(perspective_keys), | ||||
|                 ) | ||||
|                 raise KeyLookupError( | ||||
|                     "Response not signed with a known key for perspective" | ||||
|                     " server %r" % (perspective_name,) | ||||
|                 ) | ||||
| 
 | ||||
|             processed_response = yield self.process_v2_response( | ||||
|                 perspective_name, response, time_added_ms=time_now_ms | ||||
|             ) | ||||
|             server_name = response["server_name"] | ||||
| 
 | ||||
|             added_keys.extend( | ||||
|                 (server_name, key_id, key) for key_id, key in processed_response.items() | ||||
|             ) | ||||
|             keys.setdefault(server_name, {}).update(processed_response) | ||||
| 
 | ||||
|         yield self.store.store_server_verify_keys( | ||||
|             perspective_name, time_now_ms, added_keys | ||||
|         ) | ||||
| 
 | ||||
|         defer.returnValue(keys) | ||||
| 
 | ||||
|     @defer.inlineCallbacks | ||||
|     def get_server_verify_key_v2_direct(self, server_name, key_ids): | ||||
|         keys = {}  # type: dict[str, FetchKeyResult] | ||||
| 
 | ||||
|         for requested_key_id in key_ids: | ||||
|             if requested_key_id in keys: | ||||
|                 continue | ||||
| 
 | ||||
|             time_now_ms = self.clock.time_msec() | ||||
|             try: | ||||
|                 response = yield self.client.get_json( | ||||
|                     destination=server_name, | ||||
|                     path="/_matrix/key/v2/server/" | ||||
|                     + urllib.parse.quote(requested_key_id), | ||||
|                     ignore_backoff=True, | ||||
|                 ) | ||||
|             except (NotRetryingDestination, RequestSendFailed) as e: | ||||
|                 raise_from(KeyLookupError("Failed to connect to remote server"), e) | ||||
|             except HttpResponseException as e: | ||||
|                 raise_from(KeyLookupError("Remote server returned an error"), e) | ||||
| 
 | ||||
|             if ( | ||||
|                 u"signatures" not in response | ||||
|                 or server_name not in response[u"signatures"] | ||||
|             ): | ||||
|                 raise KeyLookupError("Key response not signed by remote server") | ||||
| 
 | ||||
|             if response["server_name"] != server_name: | ||||
|                 raise KeyLookupError( | ||||
|                     "Expected a response for server %r not %r" | ||||
|                     % (server_name, response["server_name"]) | ||||
|                 ) | ||||
| 
 | ||||
|             response_keys = yield self.process_v2_response( | ||||
|                 from_server=server_name, | ||||
|                 requested_ids=[requested_key_id], | ||||
|                 response_json=response, | ||||
|                 time_added_ms=time_now_ms, | ||||
|             ) | ||||
|             yield self.store.store_server_verify_keys( | ||||
|                 server_name, | ||||
|                 time_now_ms, | ||||
|                 ((server_name, key_id, key) for key_id, key in response_keys.items()), | ||||
|             ) | ||||
|             keys.update(response_keys) | ||||
| 
 | ||||
|         defer.returnValue({server_name: keys}) | ||||
| class BaseV2KeyFetcher(object): | ||||
|     def __init__(self, hs): | ||||
|         self.store = hs.get_datastore() | ||||
|         self.config = hs.get_config() | ||||
| 
 | ||||
|     @defer.inlineCallbacks | ||||
|     def process_v2_response( | ||||
|  | @ -670,6 +485,226 @@ class Keyring(object): | |||
|         defer.returnValue(verify_keys) | ||||
| 
 | ||||
| 
 | ||||
| class PerspectivesKeyFetcher(BaseV2KeyFetcher): | ||||
|     """KeyFetcher impl which fetches keys from the "perspectives" servers""" | ||||
| 
 | ||||
|     def __init__(self, hs): | ||||
|         super(PerspectivesKeyFetcher, self).__init__(hs) | ||||
|         self.clock = hs.get_clock() | ||||
|         self.client = hs.get_http_client() | ||||
|         self.perspective_servers = self.config.perspectives | ||||
| 
 | ||||
|     @defer.inlineCallbacks | ||||
|     def get_keys(self, server_name_and_key_ids): | ||||
|         """see KeyFetcher.get_keys""" | ||||
| 
 | ||||
|         @defer.inlineCallbacks | ||||
|         def get_key(perspective_name, perspective_keys): | ||||
|             try: | ||||
|                 result = yield self.get_server_verify_key_v2_indirect( | ||||
|                     server_name_and_key_ids, perspective_name, perspective_keys | ||||
|                 ) | ||||
|                 defer.returnValue(result) | ||||
|             except KeyLookupError as e: | ||||
|                 logger.warning("Key lookup failed from %r: %s", perspective_name, e) | ||||
|             except Exception as e: | ||||
|                 logger.exception( | ||||
|                     "Unable to get key from %r: %s %s", | ||||
|                     perspective_name, | ||||
|                     type(e).__name__, | ||||
|                     str(e), | ||||
|                 ) | ||||
| 
 | ||||
|             defer.returnValue({}) | ||||
| 
 | ||||
|         results = yield logcontext.make_deferred_yieldable( | ||||
|             defer.gatherResults( | ||||
|                 [ | ||||
|                     run_in_background(get_key, p_name, p_keys) | ||||
|                     for p_name, p_keys in self.perspective_servers.items() | ||||
|                 ], | ||||
|                 consumeErrors=True, | ||||
|             ).addErrback(unwrapFirstError) | ||||
|         ) | ||||
| 
 | ||||
|         union_of_keys = {} | ||||
|         for result in results: | ||||
|             for server_name, keys in result.items(): | ||||
|                 union_of_keys.setdefault(server_name, {}).update(keys) | ||||
| 
 | ||||
|         defer.returnValue(union_of_keys) | ||||
| 
 | ||||
|     @defer.inlineCallbacks | ||||
|     def get_server_verify_key_v2_indirect( | ||||
|         self, server_names_and_key_ids, perspective_name, perspective_keys | ||||
|     ): | ||||
|         """ | ||||
|         Args: | ||||
|             server_names_and_key_ids (iterable[Tuple[str, iterable[str]]]): | ||||
|                 list of (server_name, iterable[key_id]) tuples to fetch keys for | ||||
|             perspective_name (str): name of the notary server to query for the keys | ||||
|             perspective_keys (dict[str, VerifyKey]): map of key_id->key for the | ||||
|                 notary server | ||||
| 
 | ||||
|         Returns: | ||||
|             Deferred[dict[str, dict[str, synapse.storage.keys.FetchKeyResult]]]: map | ||||
|                 from server_name -> key_id -> FetchKeyResult | ||||
|         """ | ||||
|         # TODO(mark): Set the minimum_valid_until_ts to that needed by | ||||
|         # the events being validated or the current time if validating | ||||
|         # an incoming request. | ||||
|         try: | ||||
|             query_response = yield self.client.post_json( | ||||
|                 destination=perspective_name, | ||||
|                 path="/_matrix/key/v2/query", | ||||
|                 data={ | ||||
|                     u"server_keys": { | ||||
|                         server_name: { | ||||
|                             key_id: {u"minimum_valid_until_ts": 0} for key_id in key_ids | ||||
|                         } | ||||
|                         for server_name, key_ids in server_names_and_key_ids | ||||
|                     } | ||||
|                 }, | ||||
|                 long_retries=True, | ||||
|             ) | ||||
|         except (NotRetryingDestination, RequestSendFailed) as e: | ||||
|             raise_from(KeyLookupError("Failed to connect to remote server"), e) | ||||
|         except HttpResponseException as e: | ||||
|             raise_from(KeyLookupError("Remote server returned an error"), e) | ||||
| 
 | ||||
|         keys = {} | ||||
|         added_keys = [] | ||||
| 
 | ||||
|         time_now_ms = self.clock.time_msec() | ||||
| 
 | ||||
|         for response in query_response["server_keys"]: | ||||
|             if ( | ||||
|                 u"signatures" not in response | ||||
|                 or perspective_name not in response[u"signatures"] | ||||
|             ): | ||||
|                 raise KeyLookupError( | ||||
|                     "Key response not signed by perspective server" | ||||
|                     " %r" % (perspective_name,) | ||||
|                 ) | ||||
| 
 | ||||
|             verified = False | ||||
|             for key_id in response[u"signatures"][perspective_name]: | ||||
|                 if key_id in perspective_keys: | ||||
|                     verify_signed_json( | ||||
|                         response, perspective_name, perspective_keys[key_id] | ||||
|                     ) | ||||
|                     verified = True | ||||
| 
 | ||||
|             if not verified: | ||||
|                 logging.info( | ||||
|                     "Response from perspective server %r not signed with a" | ||||
|                     " known key, signed with: %r, known keys: %r", | ||||
|                     perspective_name, | ||||
|                     list(response[u"signatures"][perspective_name]), | ||||
|                     list(perspective_keys), | ||||
|                 ) | ||||
|                 raise KeyLookupError( | ||||
|                     "Response not signed with a known key for perspective" | ||||
|                     " server %r" % (perspective_name,) | ||||
|                 ) | ||||
| 
 | ||||
|             processed_response = yield self.process_v2_response( | ||||
|                 perspective_name, response, time_added_ms=time_now_ms | ||||
|             ) | ||||
|             server_name = response["server_name"] | ||||
| 
 | ||||
|             added_keys.extend( | ||||
|                 (server_name, key_id, key) for key_id, key in processed_response.items() | ||||
|             ) | ||||
|             keys.setdefault(server_name, {}).update(processed_response) | ||||
| 
 | ||||
|         yield self.store.store_server_verify_keys( | ||||
|             perspective_name, time_now_ms, added_keys | ||||
|         ) | ||||
| 
 | ||||
|         defer.returnValue(keys) | ||||
| 
 | ||||
| 
 | ||||
| class ServerKeyFetcher(BaseV2KeyFetcher): | ||||
|     """KeyFetcher impl which fetches keys from the origin servers""" | ||||
| 
 | ||||
|     def __init__(self, hs): | ||||
|         super(ServerKeyFetcher, self).__init__(hs) | ||||
|         self.clock = hs.get_clock() | ||||
|         self.client = hs.get_http_client() | ||||
| 
 | ||||
|     @defer.inlineCallbacks | ||||
|     def get_keys(self, server_name_and_key_ids): | ||||
|         """see KeyFetcher.get_keys""" | ||||
|         results = yield logcontext.make_deferred_yieldable( | ||||
|             defer.gatherResults( | ||||
|                 [ | ||||
|                     run_in_background( | ||||
|                         self.get_server_verify_key_v2_direct, server_name, key_ids | ||||
|                     ) | ||||
|                     for server_name, key_ids in server_name_and_key_ids | ||||
|                 ], | ||||
|                 consumeErrors=True, | ||||
|             ).addErrback(unwrapFirstError) | ||||
|         ) | ||||
| 
 | ||||
|         merged = {} | ||||
|         for result in results: | ||||
|             merged.update(result) | ||||
| 
 | ||||
|         defer.returnValue( | ||||
|             {server_name: keys for server_name, keys in merged.items() if keys} | ||||
|         ) | ||||
| 
 | ||||
|     @defer.inlineCallbacks | ||||
|     def get_server_verify_key_v2_direct(self, server_name, key_ids): | ||||
|         keys = {}  # type: dict[str, FetchKeyResult] | ||||
| 
 | ||||
|         for requested_key_id in key_ids: | ||||
|             if requested_key_id in keys: | ||||
|                 continue | ||||
| 
 | ||||
|             time_now_ms = self.clock.time_msec() | ||||
|             try: | ||||
|                 response = yield self.client.get_json( | ||||
|                     destination=server_name, | ||||
|                     path="/_matrix/key/v2/server/" | ||||
|                     + urllib.parse.quote(requested_key_id), | ||||
|                     ignore_backoff=True, | ||||
|                 ) | ||||
|             except (NotRetryingDestination, RequestSendFailed) as e: | ||||
|                 raise_from(KeyLookupError("Failed to connect to remote server"), e) | ||||
|             except HttpResponseException as e: | ||||
|                 raise_from(KeyLookupError("Remote server returned an error"), e) | ||||
| 
 | ||||
|             if ( | ||||
|                 u"signatures" not in response | ||||
|                 or server_name not in response[u"signatures"] | ||||
|             ): | ||||
|                 raise KeyLookupError("Key response not signed by remote server") | ||||
| 
 | ||||
|             if response["server_name"] != server_name: | ||||
|                 raise KeyLookupError( | ||||
|                     "Expected a response for server %r not %r" | ||||
|                     % (server_name, response["server_name"]) | ||||
|                 ) | ||||
| 
 | ||||
|             response_keys = yield self.process_v2_response( | ||||
|                 from_server=server_name, | ||||
|                 requested_ids=[requested_key_id], | ||||
|                 response_json=response, | ||||
|                 time_added_ms=time_now_ms, | ||||
|             ) | ||||
|             yield self.store.store_server_verify_keys( | ||||
|                 server_name, | ||||
|                 time_now_ms, | ||||
|                 ((server_name, key_id, key) for key_id, key in response_keys.items()), | ||||
|             ) | ||||
|             keys.update(response_keys) | ||||
| 
 | ||||
|         defer.returnValue({server_name: keys}) | ||||
| 
 | ||||
| 
 | ||||
| @defer.inlineCallbacks | ||||
| def _handle_key_deferred(verify_request): | ||||
|     """Waits for the key to become available, and then performs a verification | ||||
|  |  | |||
|  | @ -24,7 +24,11 @@ from twisted.internet import defer | |||
| 
 | ||||
| from synapse.api.errors import SynapseError | ||||
| from synapse.crypto import keyring | ||||
| from synapse.crypto.keyring import KeyLookupError | ||||
| from synapse.crypto.keyring import ( | ||||
|     KeyLookupError, | ||||
|     PerspectivesKeyFetcher, | ||||
|     ServerKeyFetcher, | ||||
| ) | ||||
| from synapse.storage.keys import FetchKeyResult | ||||
| from synapse.util import logcontext | ||||
| from synapse.util.logcontext import LoggingContext | ||||
|  | @ -218,12 +222,19 @@ class KeyringTestCase(unittest.HomeserverTestCase): | |||
|         self.assertFalse(d.called) | ||||
|         self.get_success(d) | ||||
| 
 | ||||
| 
 | ||||
| class ServerKeyFetcherTestCase(unittest.HomeserverTestCase): | ||||
|     def make_homeserver(self, reactor, clock): | ||||
|         self.http_client = Mock() | ||||
|         hs = self.setup_test_homeserver(handlers=None, http_client=self.http_client) | ||||
|         return hs | ||||
| 
 | ||||
|     def test_get_keys_from_server(self): | ||||
|         # arbitrarily advance the clock a bit | ||||
|         self.reactor.advance(100) | ||||
| 
 | ||||
|         SERVER_NAME = "server2" | ||||
|         kr = keyring.Keyring(self.hs) | ||||
|         fetcher = ServerKeyFetcher(self.hs) | ||||
|         testkey = signedjson.key.generate_signing_key("ver1") | ||||
|         testverifykey = signedjson.key.get_verify_key(testkey) | ||||
|         testverifykey_id = "ed25519:ver1" | ||||
|  | @ -250,7 +261,7 @@ class KeyringTestCase(unittest.HomeserverTestCase): | |||
|         self.http_client.get_json.side_effect = get_json | ||||
| 
 | ||||
|         server_name_and_key_ids = [(SERVER_NAME, ("key1",))] | ||||
|         keys = self.get_success(kr.get_keys_from_server(server_name_and_key_ids)) | ||||
|         keys = self.get_success(fetcher.get_keys(server_name_and_key_ids)) | ||||
|         k = keys[SERVER_NAME][testverifykey_id] | ||||
|         self.assertEqual(k.valid_until_ts, VALID_UNTIL_TS) | ||||
|         self.assertEqual(k.verify_key, testverifykey) | ||||
|  | @ -278,15 +289,26 @@ class KeyringTestCase(unittest.HomeserverTestCase): | |||
|         # change the server name: it should cause a rejection | ||||
|         response["server_name"] = "OTHER_SERVER" | ||||
|         self.get_failure( | ||||
|             kr.get_keys_from_server(server_name_and_key_ids), KeyLookupError | ||||
|             fetcher.get_keys(server_name_and_key_ids), KeyLookupError | ||||
|         ) | ||||
| 
 | ||||
| 
 | ||||
| class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase): | ||||
|     def make_homeserver(self, reactor, clock): | ||||
|         self.mock_perspective_server = MockPerspectiveServer() | ||||
|         self.http_client = Mock() | ||||
|         hs = self.setup_test_homeserver(handlers=None, http_client=self.http_client) | ||||
|         keys = self.mock_perspective_server.get_verify_keys() | ||||
|         hs.config.perspectives = {self.mock_perspective_server.server_name: keys} | ||||
|         return hs | ||||
| 
 | ||||
|     def test_get_keys_from_perspectives(self): | ||||
|         # arbitrarily advance the clock a bit | ||||
|         self.reactor.advance(100) | ||||
| 
 | ||||
|         fetcher = PerspectivesKeyFetcher(self.hs) | ||||
| 
 | ||||
|         SERVER_NAME = "server2" | ||||
|         kr = keyring.Keyring(self.hs) | ||||
|         testkey = signedjson.key.generate_signing_key("ver1") | ||||
|         testverifykey = signedjson.key.get_verify_key(testkey) | ||||
|         testverifykey_id = "ed25519:ver1" | ||||
|  | @ -320,7 +342,7 @@ class KeyringTestCase(unittest.HomeserverTestCase): | |||
|         self.http_client.post_json.side_effect = post_json | ||||
| 
 | ||||
|         server_name_and_key_ids = [(SERVER_NAME, ("key1",))] | ||||
|         keys = self.get_success(kr.get_keys_from_perspectives(server_name_and_key_ids)) | ||||
|         keys = self.get_success(fetcher.get_keys(server_name_and_key_ids)) | ||||
|         self.assertIn(SERVER_NAME, keys) | ||||
|         k = keys[SERVER_NAME][testverifykey_id] | ||||
|         self.assertEqual(k.valid_until_ts, VALID_UNTIL_TS) | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue
	
	 Richard van der Hoff
						Richard van der Hoff