Claim fallback keys in bulk (#16570)
							parent
							
								
									a3f6200d65
								
							
						
					
					
						commit
						fdce83ee60
					
				| 
						 | 
				
			
			@ -0,0 +1 @@
 | 
			
		|||
Improve the performance of claiming encryption keys.
 | 
			
		||||
| 
						 | 
				
			
			@ -659,6 +659,20 @@ class E2eKeysHandler:
 | 
			
		|||
        timeout: Optional[int],
 | 
			
		||||
        always_include_fallback_keys: bool,
 | 
			
		||||
    ) -> JsonDict:
 | 
			
		||||
        """
 | 
			
		||||
        Args:
 | 
			
		||||
            query: A chain of maps from (user_id, device_id, algorithm) to the requested
 | 
			
		||||
                number of keys to claim.
 | 
			
		||||
            user: The user who is claiming these keys.
 | 
			
		||||
            timeout: How long to wait for any federation key claim requests before
 | 
			
		||||
                giving up.
 | 
			
		||||
            always_include_fallback_keys: always include a fallback key for local users'
 | 
			
		||||
                devices, even if we managed to claim a one-time-key.
 | 
			
		||||
 | 
			
		||||
        Returns: a heterogeneous dict with two keys:
 | 
			
		||||
            one_time_keys: chain of maps user ID -> device ID -> key ID -> key.
 | 
			
		||||
            failures: map from remote destination to a JsonDict describing the error.
 | 
			
		||||
        """
 | 
			
		||||
        local_query: List[Tuple[str, str, str, int]] = []
 | 
			
		||||
        remote_queries: Dict[str, Dict[str, Dict[str, Dict[str, int]]]] = {}
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -420,6 +420,16 @@ class LoggingTransaction:
 | 
			
		|||
        self._do_execute(self.txn.execute, sql, parameters)
 | 
			
		||||
 | 
			
		||||
    def executemany(self, sql: str, *args: Any) -> None:
 | 
			
		||||
        """Repeatedly execute the same piece of SQL with different parameters.
 | 
			
		||||
 | 
			
		||||
        See https://peps.python.org/pep-0249/#executemany. Note in particular that
 | 
			
		||||
 | 
			
		||||
        > Use of this method for an operation which produces one or more result sets
 | 
			
		||||
        > constitutes undefined behavior
 | 
			
		||||
 | 
			
		||||
        so you can't use this for e.g. a SELECT, an UPDATE ... RETURNING, or a
 | 
			
		||||
        DELETE FROM... RETURNING.
 | 
			
		||||
        """
 | 
			
		||||
        # TODO: we should add a type for *args here. Looking at Cursor.executemany
 | 
			
		||||
        # and DBAPI2 it ought to be Sequence[_Parameter], but we pass in
 | 
			
		||||
        # Iterable[Iterable[Any]] in execute_batch and execute_values above, which mypy
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -24,6 +24,7 @@ from typing import (
 | 
			
		|||
    Mapping,
 | 
			
		||||
    Optional,
 | 
			
		||||
    Sequence,
 | 
			
		||||
    Set,
 | 
			
		||||
    Tuple,
 | 
			
		||||
    Union,
 | 
			
		||||
    cast,
 | 
			
		||||
| 
						 | 
				
			
			@ -1260,6 +1261,65 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
 | 
			
		|||
        Returns:
 | 
			
		||||
            A map of user ID -> a map device ID -> a map of key ID -> JSON.
 | 
			
		||||
        """
 | 
			
		||||
        if isinstance(self.database_engine, PostgresEngine):
 | 
			
		||||
            return await self.db_pool.runInteraction(
 | 
			
		||||
                "_claim_e2e_fallback_keys_bulk",
 | 
			
		||||
                self._claim_e2e_fallback_keys_bulk_txn,
 | 
			
		||||
                query_list,
 | 
			
		||||
                db_autocommit=True,
 | 
			
		||||
            )
 | 
			
		||||
            # Use an UPDATE FROM... RETURNING combined with a VALUES block to do
 | 
			
		||||
            # everything in one query. Note: this is also supported in SQLite 3.33.0,
 | 
			
		||||
            # (see https://www.sqlite.org/lang_update.html#update_from), but we do not
 | 
			
		||||
            # have an equivalent of psycopg2's execute_values to do this in one query.
 | 
			
		||||
        else:
 | 
			
		||||
            return await self._claim_e2e_fallback_keys_simple(query_list)
 | 
			
		||||
 | 
			
		||||
    def _claim_e2e_fallback_keys_bulk_txn(
 | 
			
		||||
        self,
 | 
			
		||||
        txn: LoggingTransaction,
 | 
			
		||||
        query_list: Iterable[Tuple[str, str, str, bool]],
 | 
			
		||||
    ) -> Dict[str, Dict[str, Dict[str, JsonDict]]]:
 | 
			
		||||
        """Efficient implementation of claim_e2e_fallback_keys for Postgres.
 | 
			
		||||
 | 
			
		||||
        Safe to autocommit: this is a single query.
 | 
			
		||||
        """
 | 
			
		||||
        results: Dict[str, Dict[str, Dict[str, JsonDict]]] = {}
 | 
			
		||||
 | 
			
		||||
        sql = """
 | 
			
		||||
            WITH claims(user_id, device_id, algorithm, mark_as_used) AS (
 | 
			
		||||
                VALUES ?
 | 
			
		||||
            )
 | 
			
		||||
            UPDATE e2e_fallback_keys_json k
 | 
			
		||||
            SET used = used OR mark_as_used
 | 
			
		||||
            FROM claims
 | 
			
		||||
            WHERE (k.user_id, k.device_id, k.algorithm) = (claims.user_id, claims.device_id, claims.algorithm)
 | 
			
		||||
            RETURNING k.user_id, k.device_id, k.algorithm, k.key_id, k.key_json;
 | 
			
		||||
        """
 | 
			
		||||
        claimed_keys = cast(
 | 
			
		||||
            List[Tuple[str, str, str, str, str]],
 | 
			
		||||
            txn.execute_values(sql, query_list),
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        seen_user_device: Set[Tuple[str, str]] = set()
 | 
			
		||||
        for user_id, device_id, algorithm, key_id, key_json in claimed_keys:
 | 
			
		||||
            device_results = results.setdefault(user_id, {}).setdefault(device_id, {})
 | 
			
		||||
            device_results[f"{algorithm}:{key_id}"] = json_decoder.decode(key_json)
 | 
			
		||||
 | 
			
		||||
            if (user_id, device_id) in seen_user_device:
 | 
			
		||||
                continue
 | 
			
		||||
            seen_user_device.add((user_id, device_id))
 | 
			
		||||
            self._invalidate_cache_and_stream(
 | 
			
		||||
                txn, self.get_e2e_unused_fallback_key_types, (user_id, device_id)
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
        return results
 | 
			
		||||
 | 
			
		||||
    async def _claim_e2e_fallback_keys_simple(
 | 
			
		||||
        self,
 | 
			
		||||
        query_list: Iterable[Tuple[str, str, str, bool]],
 | 
			
		||||
    ) -> Dict[str, Dict[str, Dict[str, JsonDict]]]:
 | 
			
		||||
        """Naive, inefficient implementation of claim_e2e_fallback_keys for SQLite."""
 | 
			
		||||
        results: Dict[str, Dict[str, Dict[str, JsonDict]]] = {}
 | 
			
		||||
        for user_id, device_id, algorithm, mark_as_used in query_list:
 | 
			
		||||
            row = await self.db_pool.simple_select_one(
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -322,6 +322,83 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
 | 
			
		|||
            {"failures": {}, "one_time_keys": {local_user: {device_id: fallback_key3}}},
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    def test_fallback_key_bulk(self) -> None:
 | 
			
		||||
        """Like test_fallback_key, but claims multiple keys in one handler call."""
 | 
			
		||||
        alice = f"@alice:{self.hs.hostname}"
 | 
			
		||||
        brian = f"@brian:{self.hs.hostname}"
 | 
			
		||||
        chris = f"@chris:{self.hs.hostname}"
 | 
			
		||||
 | 
			
		||||
        # Have three users upload fallback keys for two devices.
 | 
			
		||||
        fallback_keys = {
 | 
			
		||||
            alice: {
 | 
			
		||||
                "alice_dev_1": {"alg1:k1": "fallback_key1"},
 | 
			
		||||
                "alice_dev_2": {"alg2:k2": "fallback_key2"},
 | 
			
		||||
            },
 | 
			
		||||
            brian: {
 | 
			
		||||
                "brian_dev_1": {"alg1:k3": "fallback_key3"},
 | 
			
		||||
                "brian_dev_2": {"alg2:k4": "fallback_key4"},
 | 
			
		||||
            },
 | 
			
		||||
            chris: {
 | 
			
		||||
                "chris_dev_1": {"alg1:k5": "fallback_key5"},
 | 
			
		||||
                "chris_dev_2": {"alg2:k6": "fallback_key6"},
 | 
			
		||||
            },
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        for user_id, devices in fallback_keys.items():
 | 
			
		||||
            for device_id, key_dict in devices.items():
 | 
			
		||||
                self.get_success(
 | 
			
		||||
                    self.handler.upload_keys_for_user(
 | 
			
		||||
                        user_id,
 | 
			
		||||
                        device_id,
 | 
			
		||||
                        {"fallback_keys": key_dict},
 | 
			
		||||
                    )
 | 
			
		||||
                )
 | 
			
		||||
 | 
			
		||||
        # Each device should have an unused fallback key.
 | 
			
		||||
        for user_id, devices in fallback_keys.items():
 | 
			
		||||
            for device_id in devices:
 | 
			
		||||
                fallback_res = self.get_success(
 | 
			
		||||
                    self.store.get_e2e_unused_fallback_key_types(user_id, device_id)
 | 
			
		||||
                )
 | 
			
		||||
                expected_algorithm_name = f"alg{device_id[-1]}"
 | 
			
		||||
                self.assertEqual(fallback_res, [expected_algorithm_name])
 | 
			
		||||
 | 
			
		||||
        # Claim the fallback key for one device per user.
 | 
			
		||||
        claim_res = self.get_success(
 | 
			
		||||
            self.handler.claim_one_time_keys(
 | 
			
		||||
                {
 | 
			
		||||
                    alice: {"alice_dev_1": {"alg1": 1}},
 | 
			
		||||
                    brian: {"brian_dev_2": {"alg2": 1}},
 | 
			
		||||
                    chris: {"chris_dev_2": {"alg2": 1}},
 | 
			
		||||
                },
 | 
			
		||||
                self.requester,
 | 
			
		||||
                timeout=None,
 | 
			
		||||
                always_include_fallback_keys=False,
 | 
			
		||||
            )
 | 
			
		||||
        )
 | 
			
		||||
        expected_claims = {
 | 
			
		||||
            alice: {"alice_dev_1": {"alg1:k1": "fallback_key1"}},
 | 
			
		||||
            brian: {"brian_dev_2": {"alg2:k4": "fallback_key4"}},
 | 
			
		||||
            chris: {"chris_dev_2": {"alg2:k6": "fallback_key6"}},
 | 
			
		||||
        }
 | 
			
		||||
        self.assertEqual(
 | 
			
		||||
            claim_res,
 | 
			
		||||
            {"failures": {}, "one_time_keys": expected_claims},
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        for user_id, devices in fallback_keys.items():
 | 
			
		||||
            for device_id in devices:
 | 
			
		||||
                fallback_res = self.get_success(
 | 
			
		||||
                    self.store.get_e2e_unused_fallback_key_types(user_id, device_id)
 | 
			
		||||
                )
 | 
			
		||||
                # Claimed fallback keys should no longer show up as unused.
 | 
			
		||||
                # Unclaimed fallback keys should still be unused.
 | 
			
		||||
                if device_id in expected_claims[user_id]:
 | 
			
		||||
                    self.assertEqual(fallback_res, [])
 | 
			
		||||
                else:
 | 
			
		||||
                    expected_algorithm_name = f"alg{device_id[-1]}"
 | 
			
		||||
                    self.assertEqual(fallback_res, [expected_algorithm_name])
 | 
			
		||||
 | 
			
		||||
    def test_fallback_key_always_returned(self) -> None:
 | 
			
		||||
        local_user = "@boris:" + self.hs.hostname
 | 
			
		||||
        device_id = "xyz"
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in New Issue