Merge pull request #273 from matrix-org/erikj/key_fetch_fix
Various bug fixes to crypto.keyringreviewable/pr275/r1
						commit
						a5b41b809f
					
				| 
						 | 
					@ -162,7 +162,9 @@ class Keyring(object):
 | 
				
			||||||
        def remove_deferreds(res, server_name, group_id):
 | 
					        def remove_deferreds(res, server_name, group_id):
 | 
				
			||||||
            server_to_gids[server_name].discard(group_id)
 | 
					            server_to_gids[server_name].discard(group_id)
 | 
				
			||||||
            if not server_to_gids[server_name]:
 | 
					            if not server_to_gids[server_name]:
 | 
				
			||||||
                server_to_deferred.pop(server_name).callback(None)
 | 
					                d = server_to_deferred.pop(server_name, None)
 | 
				
			||||||
 | 
					                if d:
 | 
				
			||||||
 | 
					                    d.callback(None)
 | 
				
			||||||
            return res
 | 
					            return res
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        for g_id, deferred in deferreds.items():
 | 
					        for g_id, deferred in deferreds.items():
 | 
				
			||||||
| 
						 | 
					@ -200,8 +202,15 @@ class Keyring(object):
 | 
				
			||||||
            else:
 | 
					            else:
 | 
				
			||||||
                break
 | 
					                break
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        for server_name, deferred in server_to_deferred:
 | 
					        for server_name, deferred in server_to_deferred.items():
 | 
				
			||||||
            self.key_downloads[server_name] = ObservableDeferred(deferred)
 | 
					            d = ObservableDeferred(deferred)
 | 
				
			||||||
 | 
					            self.key_downloads[server_name] = d
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            def rm(r, server_name):
 | 
				
			||||||
 | 
					                self.key_downloads.pop(server_name, None)
 | 
				
			||||||
 | 
					                return r
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            d.addBoth(rm, server_name)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def get_server_verify_keys(self, group_id_to_group, group_id_to_deferred):
 | 
					    def get_server_verify_keys(self, group_id_to_group, group_id_to_deferred):
 | 
				
			||||||
        """Takes a dict of KeyGroups and tries to find at least one key for
 | 
					        """Takes a dict of KeyGroups and tries to find at least one key for
 | 
				
			||||||
| 
						 | 
					@ -220,9 +229,8 @@ class Keyring(object):
 | 
				
			||||||
            merged_results = {}
 | 
					            merged_results = {}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            missing_keys = {
 | 
					            missing_keys = {
 | 
				
			||||||
                group.server_name: key_id
 | 
					                group.server_name: set(group.key_ids)
 | 
				
			||||||
                for group in group_id_to_group.values()
 | 
					                for group in group_id_to_group.values()
 | 
				
			||||||
                for key_id in group.key_ids
 | 
					 | 
				
			||||||
            }
 | 
					            }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            for fn in key_fetch_fns:
 | 
					            for fn in key_fetch_fns:
 | 
				
			||||||
| 
						 | 
					@ -279,16 +287,15 @@ class Keyring(object):
 | 
				
			||||||
    def get_keys_from_store(self, server_name_and_key_ids):
 | 
					    def get_keys_from_store(self, server_name_and_key_ids):
 | 
				
			||||||
        res = yield defer.gatherResults(
 | 
					        res = yield defer.gatherResults(
 | 
				
			||||||
            [
 | 
					            [
 | 
				
			||||||
                self.store.get_server_verify_keys(server_name, key_ids)
 | 
					                self.store.get_server_verify_keys(
 | 
				
			||||||
 | 
					                    server_name, key_ids
 | 
				
			||||||
 | 
					                ).addCallback(lambda ks, server: (server, ks), server_name)
 | 
				
			||||||
                for server_name, key_ids in server_name_and_key_ids
 | 
					                for server_name, key_ids in server_name_and_key_ids
 | 
				
			||||||
            ],
 | 
					            ],
 | 
				
			||||||
            consumeErrors=True,
 | 
					            consumeErrors=True,
 | 
				
			||||||
        ).addErrback(unwrapFirstError)
 | 
					        ).addErrback(unwrapFirstError)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        defer.returnValue(dict(zip(
 | 
					        defer.returnValue(dict(res))
 | 
				
			||||||
            [server_name for server_name, _ in server_name_and_key_ids],
 | 
					 | 
				
			||||||
            res
 | 
					 | 
				
			||||||
        )))
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @defer.inlineCallbacks
 | 
					    @defer.inlineCallbacks
 | 
				
			||||||
    def get_keys_from_perspectives(self, server_name_and_key_ids):
 | 
					    def get_keys_from_perspectives(self, server_name_and_key_ids):
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in New Issue