Merge pull request #957 from matrix-org/markjh/verify
Clean up verify_json_objects_for_serverpull/958/head
						commit
						c38b7c4104
					
				| 
						 | 
				
			
			@ -44,7 +44,21 @@ import logging
 | 
			
		|||
logger = logging.getLogger(__name__)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
KeyGroup = namedtuple("KeyGroup", ("server_name", "group_id", "key_ids"))
 | 
			
		||||
VerifyKeyRequest = namedtuple("VerifyRequest", (
 | 
			
		||||
    "server_name", "key_ids", "json_object", "deferred"
 | 
			
		||||
))
 | 
			
		||||
"""
 | 
			
		||||
A request for a verify key to verify a JSON object.
 | 
			
		||||
 | 
			
		||||
Attributes:
 | 
			
		||||
    server_name(str): The name of the server to verify against.
 | 
			
		||||
    key_ids(set(str)): The set of key_ids to that could be used to verify the
 | 
			
		||||
        JSON object
 | 
			
		||||
    json_object(dict): The JSON object to verify.
 | 
			
		||||
    deferred(twisted.internet.defer.Deferred):
 | 
			
		||||
        A deferred (server_name, key_id, verify_key) tuple that resolves when
 | 
			
		||||
        a verify key has been fetched
 | 
			
		||||
"""
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class Keyring(object):
 | 
			
		||||
| 
						 | 
				
			
			@ -74,39 +88,32 @@ class Keyring(object):
 | 
			
		|||
            list of deferreds indicating success or failure to verify each
 | 
			
		||||
            json object's signature for the given server_name.
 | 
			
		||||
        """
 | 
			
		||||
        group_id_to_json = {}
 | 
			
		||||
        group_id_to_group = {}
 | 
			
		||||
        group_ids = []
 | 
			
		||||
 | 
			
		||||
        next_group_id = 0
 | 
			
		||||
        deferreds = {}
 | 
			
		||||
        verify_requests = []
 | 
			
		||||
 | 
			
		||||
        for server_name, json_object in server_and_json:
 | 
			
		||||
            logger.debug("Verifying for %s", server_name)
 | 
			
		||||
            group_id = next_group_id
 | 
			
		||||
            next_group_id += 1
 | 
			
		||||
            group_ids.append(group_id)
 | 
			
		||||
 | 
			
		||||
            key_ids = signature_ids(json_object, server_name)
 | 
			
		||||
            if not key_ids:
 | 
			
		||||
                deferreds[group_id] = defer.fail(SynapseError(
 | 
			
		||||
                deferred = defer.fail(SynapseError(
 | 
			
		||||
                    400,
 | 
			
		||||
                    "Not signed with a supported algorithm",
 | 
			
		||||
                    Codes.UNAUTHORIZED,
 | 
			
		||||
                ))
 | 
			
		||||
            else:
 | 
			
		||||
                deferreds[group_id] = defer.Deferred()
 | 
			
		||||
                deferred = defer.Deferred()
 | 
			
		||||
 | 
			
		||||
            group = KeyGroup(server_name, group_id, key_ids)
 | 
			
		||||
            verify_request = VerifyKeyRequest(
 | 
			
		||||
                server_name, key_ids, json_object, deferred
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
            group_id_to_group[group_id] = group
 | 
			
		||||
            group_id_to_json[group_id] = json_object
 | 
			
		||||
            verify_requests.append(verify_request)
 | 
			
		||||
 | 
			
		||||
        @defer.inlineCallbacks
 | 
			
		||||
        def handle_key_deferred(group, deferred):
 | 
			
		||||
            server_name = group.server_name
 | 
			
		||||
        def handle_key_deferred(verify_request):
 | 
			
		||||
            server_name = verify_request.server_name
 | 
			
		||||
            try:
 | 
			
		||||
                _, _, key_id, verify_key = yield deferred
 | 
			
		||||
                _, key_id, verify_key = yield verify_request.deferred
 | 
			
		||||
            except IOError as e:
 | 
			
		||||
                logger.warn(
 | 
			
		||||
                    "Got IOError when downloading keys for %s: %s %s",
 | 
			
		||||
| 
						 | 
				
			
			@ -128,7 +135,7 @@ class Keyring(object):
 | 
			
		|||
                    Codes.UNAUTHORIZED,
 | 
			
		||||
                )
 | 
			
		||||
 | 
			
		||||
            json_object = group_id_to_json[group.group_id]
 | 
			
		||||
            json_object = verify_request.json_object
 | 
			
		||||
 | 
			
		||||
            try:
 | 
			
		||||
                verify_signed_json(json_object, server_name, verify_key)
 | 
			
		||||
| 
						 | 
				
			
			@ -157,36 +164,34 @@ class Keyring(object):
 | 
			
		|||
 | 
			
		||||
            # Actually start fetching keys.
 | 
			
		||||
            wait_on_deferred.addBoth(
 | 
			
		||||
                lambda _: self.get_server_verify_keys(group_id_to_group, deferreds)
 | 
			
		||||
                lambda _: self.get_server_verify_keys(verify_requests)
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
            # When we've finished fetching all the keys for a given server_name,
 | 
			
		||||
            # resolve the deferred passed to `wait_for_previous_lookups` so that
 | 
			
		||||
            # any lookups waiting will proceed.
 | 
			
		||||
            server_to_gids = {}
 | 
			
		||||
            server_to_request_ids = {}
 | 
			
		||||
 | 
			
		||||
            def remove_deferreds(res, server_name, group_id):
 | 
			
		||||
                server_to_gids[server_name].discard(group_id)
 | 
			
		||||
                if not server_to_gids[server_name]:
 | 
			
		||||
            def remove_deferreds(res, server_name, verify_request):
 | 
			
		||||
                request_id = id(verify_request)
 | 
			
		||||
                server_to_request_ids[server_name].discard(request_id)
 | 
			
		||||
                if not server_to_request_ids[server_name]:
 | 
			
		||||
                    d = server_to_deferred.pop(server_name, None)
 | 
			
		||||
                    if d:
 | 
			
		||||
                        d.callback(None)
 | 
			
		||||
                return res
 | 
			
		||||
 | 
			
		||||
            for g_id, deferred in deferreds.items():
 | 
			
		||||
                server_name = group_id_to_group[g_id].server_name
 | 
			
		||||
                server_to_gids.setdefault(server_name, set()).add(g_id)
 | 
			
		||||
                deferred.addBoth(remove_deferreds, server_name, g_id)
 | 
			
		||||
            for verify_request in verify_requests:
 | 
			
		||||
                server_name = verify_request.server_name
 | 
			
		||||
                request_id = id(verify_request)
 | 
			
		||||
                server_to_request_ids.setdefault(server_name, set()).add(request_id)
 | 
			
		||||
                deferred.addBoth(remove_deferreds, server_name, verify_request)
 | 
			
		||||
 | 
			
		||||
        # Pass those keys to handle_key_deferred so that the json object
 | 
			
		||||
        # signatures can be verified
 | 
			
		||||
        return [
 | 
			
		||||
            preserve_context_over_fn(
 | 
			
		||||
                handle_key_deferred,
 | 
			
		||||
                group_id_to_group[g_id],
 | 
			
		||||
                deferreds[g_id],
 | 
			
		||||
            )
 | 
			
		||||
            for g_id in group_ids
 | 
			
		||||
            preserve_context_over_fn(handle_key_deferred, verify_request)
 | 
			
		||||
            for verify_request in verify_requests
 | 
			
		||||
        ]
 | 
			
		||||
 | 
			
		||||
    @defer.inlineCallbacks
 | 
			
		||||
| 
						 | 
				
			
			@ -220,7 +225,7 @@ class Keyring(object):
 | 
			
		|||
 | 
			
		||||
            d.addBoth(rm, server_name)
 | 
			
		||||
 | 
			
		||||
    def get_server_verify_keys(self, group_id_to_group, group_id_to_deferred):
 | 
			
		||||
    def get_server_verify_keys(self, verify_requests):
 | 
			
		||||
        """Takes a dict of KeyGroups and tries to find at least one key for
 | 
			
		||||
        each group.
 | 
			
		||||
        """
 | 
			
		||||
| 
						 | 
				
			
			@ -237,63 +242,64 @@ class Keyring(object):
 | 
			
		|||
            merged_results = {}
 | 
			
		||||
 | 
			
		||||
            missing_keys = {}
 | 
			
		||||
            for group in group_id_to_group.values():
 | 
			
		||||
                missing_keys.setdefault(group.server_name, set()).update(
 | 
			
		||||
                    group.key_ids
 | 
			
		||||
            for verify_request in verify_requests:
 | 
			
		||||
                missing_keys.setdefault(verify_request.server_name, set()).update(
 | 
			
		||||
                    verify_request.key_ids
 | 
			
		||||
                )
 | 
			
		||||
 | 
			
		||||
            for fn in key_fetch_fns:
 | 
			
		||||
                results = yield fn(missing_keys.items())
 | 
			
		||||
                merged_results.update(results)
 | 
			
		||||
 | 
			
		||||
                # We now need to figure out which groups we have keys for
 | 
			
		||||
                # and which we don't
 | 
			
		||||
                missing_groups = {}
 | 
			
		||||
                for group in group_id_to_group.values():
 | 
			
		||||
                    for key_id in group.key_ids:
 | 
			
		||||
                        if key_id in merged_results[group.server_name]:
 | 
			
		||||
                # We now need to figure out which verify requests we have keys
 | 
			
		||||
                # for and which we don't
 | 
			
		||||
                missing_keys = {}
 | 
			
		||||
                requests_missing_keys = []
 | 
			
		||||
                for verify_request in verify_requests:
 | 
			
		||||
                    server_name = verify_request.server_name
 | 
			
		||||
                    result_keys = merged_results[server_name]
 | 
			
		||||
 | 
			
		||||
                    if verify_request.deferred.called:
 | 
			
		||||
                        # We've already called this deferred, which probably
 | 
			
		||||
                        # means that we've already found a key for it.
 | 
			
		||||
                        continue
 | 
			
		||||
 | 
			
		||||
                    for key_id in verify_request.key_ids:
 | 
			
		||||
                        if key_id in result_keys:
 | 
			
		||||
                            with PreserveLoggingContext():
 | 
			
		||||
                                group_id_to_deferred[group.group_id].callback((
 | 
			
		||||
                                    group.group_id,
 | 
			
		||||
                                    group.server_name,
 | 
			
		||||
                                verify_request.deferred.callback((
 | 
			
		||||
                                    server_name,
 | 
			
		||||
                                    key_id,
 | 
			
		||||
                                    merged_results[group.server_name][key_id],
 | 
			
		||||
                                    result_keys[key_id],
 | 
			
		||||
                                ))
 | 
			
		||||
                            break
 | 
			
		||||
                    else:
 | 
			
		||||
                        missing_groups.setdefault(
 | 
			
		||||
                            group.server_name, []
 | 
			
		||||
                        ).append(group)
 | 
			
		||||
                        # The else block is only reached if the loop above
 | 
			
		||||
                        # doesn't break.
 | 
			
		||||
                        missing_keys.setdefault(server_name, set()).update(
 | 
			
		||||
                            verify_request.key_ids
 | 
			
		||||
                        )
 | 
			
		||||
                        requests_missing_keys.append(verify_request)
 | 
			
		||||
 | 
			
		||||
                if not missing_groups:
 | 
			
		||||
                if not missing_keys:
 | 
			
		||||
                    break
 | 
			
		||||
 | 
			
		||||
                missing_keys = {
 | 
			
		||||
                    server_name: set(
 | 
			
		||||
                        key_id for group in groups for key_id in group.key_ids
 | 
			
		||||
                    )
 | 
			
		||||
                    for server_name, groups in missing_groups.items()
 | 
			
		||||
                }
 | 
			
		||||
 | 
			
		||||
            for groups in missing_groups.values():
 | 
			
		||||
                for group in groups:
 | 
			
		||||
                    group_id_to_deferred[group.group_id].errback(SynapseError(
 | 
			
		||||
            for verify_request in requests_missing_keys.values():
 | 
			
		||||
                verify_request.deferred.errback(SynapseError(
 | 
			
		||||
                    401,
 | 
			
		||||
                    "No key for %s with id %s" % (
 | 
			
		||||
                            group.server_name, group.key_ids,
 | 
			
		||||
                        verify_request.server_name, verify_request.key_ids,
 | 
			
		||||
                    ),
 | 
			
		||||
                    Codes.UNAUTHORIZED,
 | 
			
		||||
                ))
 | 
			
		||||
 | 
			
		||||
        def on_err(err):
 | 
			
		||||
            for deferred in group_id_to_deferred.values():
 | 
			
		||||
                if not deferred.called:
 | 
			
		||||
                    deferred.errback(err)
 | 
			
		||||
            for verify_request in verify_requests:
 | 
			
		||||
                if not verify_request.deferred.called:
 | 
			
		||||
                    verify_request.deferred.errback(err)
 | 
			
		||||
 | 
			
		||||
        do_iterations().addErrback(on_err)
 | 
			
		||||
 | 
			
		||||
        return group_id_to_deferred
 | 
			
		||||
 | 
			
		||||
    @defer.inlineCallbacks
 | 
			
		||||
    def get_keys_from_store(self, server_name_and_key_ids):
 | 
			
		||||
        res = yield defer.gatherResults(
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in New Issue