diff --git a/changelog.d/11382.misc b/changelog.d/11382.misc new file mode 100644 index 0000000000..d812ef309e --- /dev/null +++ b/changelog.d/11382.misc @@ -0,0 +1 @@ +Keep fallback key marked as used if it's re-uploaded. diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py index a95ac34f09..b06c1dc45b 100644 --- a/synapse/storage/databases/main/end_to_end_keys.py +++ b/synapse/storage/databases/main/end_to_end_keys.py @@ -408,29 +408,58 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore): fallback_keys: the keys to set. This is a map from key ID (which is of the form "algorithm:id") to key data. """ + await self.db_pool.runInteraction( + "set_e2e_fallback_keys_txn", + self._set_e2e_fallback_keys_txn, + user_id, + device_id, + fallback_keys, + ) + + await self.invalidate_cache_and_stream( + "get_e2e_unused_fallback_key_types", (user_id, device_id) + ) + + def _set_e2e_fallback_keys_txn( + self, txn: Connection, user_id: str, device_id: str, fallback_keys: JsonDict + ) -> None: # fallback_keys will usually only have one item in it, so using a for # loop (as opposed to calling simple_upsert_many_txn) won't be too bad # FIXME: make sure that only one key per algorithm is uploaded for key_id, fallback_key in fallback_keys.items(): algorithm, key_id = key_id.split(":", 1) - await self.db_pool.simple_upsert( - "e2e_fallback_keys_json", + old_key_json = self.db_pool.simple_select_one_onecol_txn( + txn, + table="e2e_fallback_keys_json", keyvalues={ "user_id": user_id, "device_id": device_id, "algorithm": algorithm, }, - values={ - "key_id": key_id, - "key_json": json_encoder.encode(fallback_key), - "used": False, - }, - desc="set_e2e_fallback_key", + retcol="key_json", + allow_none=True, ) - await self.invalidate_cache_and_stream( - "get_e2e_unused_fallback_key_types", (user_id, device_id) - ) + new_key_json = encode_canonical_json(fallback_key).decode("utf-8") + + # If the uploaded key is the same as the current fallback key, + # don't do anything. This prevents marking the key as unused if it + # was already used. + if old_key_json != new_key_json: + self.db_pool.simple_upsert_txn( + txn, + table="e2e_fallback_keys_json", + keyvalues={ + "user_id": user_id, + "device_id": device_id, + "algorithm": algorithm, + }, + values={ + "key_id": key_id, + "key_json": json_encoder.encode(fallback_key), + "used": False, + }, + ) @cached(max_entries=10000) async def get_e2e_unused_fallback_key_types( diff --git a/tests/handlers/test_e2e_keys.py b/tests/handlers/test_e2e_keys.py index 0c3b86fda9..f0723892e4 100644 --- a/tests/handlers/test_e2e_keys.py +++ b/tests/handlers/test_e2e_keys.py @@ -162,6 +162,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase): local_user = "@boris:" + self.hs.hostname device_id = "xyz" fallback_key = {"alg1:k1": "key1"} + fallback_key2 = {"alg1:k2": "key2"} otk = {"alg1:k2": "key2"} # we shouldn't have any unused fallback keys yet @@ -213,6 +214,35 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase): {"failures": {}, "one_time_keys": {local_user: {device_id: fallback_key}}}, ) + # re-uploading the same fallback key should still result in no unused fallback + # keys + self.get_success( + self.handler.upload_keys_for_user( + local_user, + device_id, + {"org.matrix.msc2732.fallback_keys": fallback_key}, + ) + ) + + res = self.get_success( + self.store.get_e2e_unused_fallback_key_types(local_user, device_id) + ) + self.assertEqual(res, []) + + # uploading a new fallback key should result in an unused fallback key + self.get_success( + self.handler.upload_keys_for_user( + local_user, + device_id, + {"org.matrix.msc2732.fallback_keys": fallback_key2}, + ) + ) + + res = self.get_success( + self.store.get_e2e_unused_fallback_key_types(local_user, device_id) + ) + self.assertEqual(res, ["alg1"]) + # if the user uploads a one-time key, the next claim should fetch the # one-time key, and then go back to the fallback self.get_success( @@ -238,7 +268,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase): ) self.assertEqual( res, - {"failures": {}, "one_time_keys": {local_user: {device_id: fallback_key}}}, + {"failures": {}, "one_time_keys": {local_user: {device_id: fallback_key2}}}, ) def test_replace_master_key(self):