Factor key retrieval out into a separate function

pull/7289/head
Andrew Morgan 2020-04-17 12:07:19 +01:00
parent 2d88b5d39d
commit f41730078e
1 changed files with 61 additions and 43 deletions

View File

@ -16,6 +16,7 @@
# limitations under the License. # limitations under the License.
import logging import logging
from typing import Dict, Optional
from six import iteritems from six import iteritems
@ -962,7 +963,7 @@ class E2eKeysHandler(object):
@defer.inlineCallbacks @defer.inlineCallbacks
def _get_e2e_cross_signing_verify_key( def _get_e2e_cross_signing_verify_key(
self, user_id: str, desired_key_type: str, from_user_id: str = None self, user_id: str, key_type: str, from_user_id: str = None
): ):
"""Fetch or request the given cross-signing public key. """Fetch or request the given cross-signing public key.
@ -972,7 +973,7 @@ class E2eKeysHandler(object):
Args: Args:
user_id: the user whose key should be fetched user_id: the user whose key should be fetched
desired_key_type: the type of key to fetch key_type: the type of key to fetch
from_user_id: the user that we are fetching the keys for. from_user_id: the user that we are fetching the keys for.
This affects what signatures are fetched. This affects what signatures are fetched.
@ -986,7 +987,7 @@ class E2eKeysHandler(object):
""" """
user = UserID.from_string(user_id) user = UserID.from_string(user_id)
key = yield self.store.get_e2e_cross_signing_key( key = yield self.store.get_e2e_cross_signing_key(
user_id, desired_key_type, from_user_id user_id, key_type, from_user_id
) )
# If we still can't find the key, and we're looking for keys of another user # If we still can't find the key, and we're looking for keys of another user
@ -996,30 +997,65 @@ class E2eKeysHandler(object):
# cross-sign a remote user, but does not share any rooms with them yet. # cross-sign a remote user, but does not share any rooms with them yet.
# Thus, we would not have their key list yet. We fetch the key here and # Thus, we would not have their key list yet. We fetch the key here and
# store it just in case. # store it just in case.
supported_remote_key_types = ["master", "self_signing"]
if ( if (
key is None key is None
and not self.is_mine(user) and not self.is_mine(user)
# We only get "master" and "self_signing" keys from remote servers # We only get "master" and "self_signing" keys from remote servers
and desired_key_type in supported_remote_key_types and key_type in ["master", "self_signing"]
): ):
remote_result = None key = yield self._retrieve_cross_signing_keys_for_remote_user(
user, key_type
)
if key is None:
logger.debug("No %s key found for %s", key_type, user_id)
raise NotFoundError("No %s key found for %s" % (key_type, user_id))
try:
key_id, verify_key = get_verify_key_from_cross_signing_key(key)
except ValueError as e:
logger.debug(
"Invalid %s key retrieved: %s - %s %s", key_type, key, type(e), e,
)
raise SynapseError(
502, "Invalid %s key retrieved from remote server", key_type
)
return key, key_id, verify_key
@defer.inlineCallbacks
def _retrieve_cross_signing_keys_for_remote_user(
self, user: UserID, desired_key_type: str,
) -> Optional[Dict]:
"""Queries cross-signing keys for a remote user and saves them to the database
Only the key specified by `key_type` will be returned, while all retrieved keys
will be saved regardless
Args:
user: The user to query remote keys for
desired_key_type: The type of key to receive. One of "master", "self_signing"
Returns:
The retrieved key content, or None if the key could not be retrieved
"""
try: try:
remote_result = yield self.federation.query_user_devices( remote_result = yield self.federation.query_user_devices(
user.domain, user_id user.domain, user.to_string()
) )
except Exception as e: except Exception as e:
logger.warning( logger.warning(
"Unable to query %s for cross-signing keys of user %s: %s %s", "Unable to query %s for cross-signing keys of user %s: %s %s",
user.domain, user.domain,
user_id, user.to_string(),
type(e), type(e),
e, e,
) )
return None
if remote_result is not None:
# Process each of the retrieved cross-signing keys # Process each of the retrieved cross-signing keys
for key_type in supported_remote_key_types: key = None
for key_type in ["master", "self_signing"]:
key_content = remote_result.get(key_type + "_key") key_content = remote_result.get(key_type + "_key")
if not key_content: if not key_content:
continue continue
@ -1031,28 +1067,10 @@ class E2eKeysHandler(object):
# At the same time, store this key in the db for # At the same time, store this key in the db for
# subsequent queries # subsequent queries
yield self.store.set_e2e_cross_signing_key( yield self.store.set_e2e_cross_signing_key(
user_id, key_type, key_content user.to_string(), key_type, key_content
) )
if key is None: return key
logger.debug("No %s key found for %s", desired_key_type, user_id)
raise NotFoundError("No %s key found for %s" % (desired_key_type, user_id))
try:
key_id, verify_key = get_verify_key_from_cross_signing_key(key)
except ValueError as e:
logger.debug(
"Invalid %s key retrieved: %s - %s %s",
desired_key_type,
key,
type(e),
e,
)
raise SynapseError(
502, "Invalid %s key retrieved from remote server", desired_key_type
)
return key, key_id, verify_key
def _check_cross_signing_key(key, user_id, key_type, signing_key=None): def _check_cross_signing_key(key, user_id, key_type, signing_key=None):