Add support for MSC2732: olm fallback keys (#8312)
parent a024461130, commit 3cd78bbe9e
@@ -0,0 +1 @@
Add support for olm fallback keys ([MSC2732](https://github.com/matrix-org/matrix-doc/pull/2732)).
@@ -90,6 +90,7 @@ BOOLEAN_COLUMNS = {
    "room_stats_state": ["is_federatable"],
    "local_media_repository": ["safe_from_quarantine"],
    "users": ["shadow_banned"],
    "e2e_fallback_keys_json": ["used"],
}
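The new "used" column needs an entry in BOOLEAN_COLUMNS because the port script copies rows out of SQLite, which stores booleans as 0/1 integers, into Postgres, which expects real booleans. A minimal standalone sketch of that conversion step (the helper name and row format are illustrative, not the synapse_port_db implementation):

from typing import Any, Dict, Iterable, List

BOOLEAN_COLUMNS: Dict[str, List[str]] = {
    "e2e_fallback_keys_json": ["used"],
}

def convert_boolean_columns(
    table: str, headers: List[str], rows: Iterable[List[Any]]
) -> List[List[Any]]:
    # Cast the listed columns from SQLite's 0/1 integers to Python booleans.
    bool_indexes = [headers.index(col) for col in BOOLEAN_COLUMNS.get(table, [])]
    out = []
    for row in rows:
        row = list(row)
        for i in bool_indexes:
            row[i] = bool(row[i])
        out.append(row)
    return out

# convert_boolean_columns("e2e_fallback_keys_json", ["user_id", "used"], [["@u:hs", 1]])
# -> [["@u:hs", True]]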
@@ -496,6 +496,22 @@ class E2eKeysHandler:
            log_kv(
                {"message": "Did not update one_time_keys", "reason": "no keys given"}
            )
        fallback_keys = keys.get("org.matrix.msc2732.fallback_keys", None)
        if fallback_keys and isinstance(fallback_keys, dict):
            log_kv(
                {
                    "message": "Updating fallback_keys for device.",
                    "user_id": user_id,
                    "device_id": device_id,
                }
            )
            await self.store.set_e2e_fallback_keys(user_id, device_id, fallback_keys)
        elif fallback_keys:
            log_kv({"message": "Did not update fallback_keys", "reason": "not a dict"})
        else:
            log_kv(
                {"message": "Did not update fallback_keys", "reason": "no keys given"}
            )

        # the device should have been registered already, but it may have been
        # deleted due to a race with a DELETE request. Or we may be using an
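For context, the fallback keys arrive in the same keys-upload body as one_time_keys, under the unstable MSC2732 field name, and the handler only persists them when the value is a dict. A self-contained sketch of the shape and the branch taken (only the request field names come from the diff above; the helper is hypothetical):

from typing import Any, Dict, Optional

def extract_fallback_keys(keys: Dict[str, Any]) -> Optional[Dict[str, Any]]:
    # Mirrors the branch above: persist only when the value is a non-empty dict.
    fallback_keys = keys.get("org.matrix.msc2732.fallback_keys", None)
    if fallback_keys and isinstance(fallback_keys, dict):
        return fallback_keys
    return None  # absent or malformed: logged, nothing stored

upload_body = {
    "one_time_keys": {"alg1:k2": "key2"},
    "org.matrix.msc2732.fallback_keys": {"alg1:k1": "key1"},
}
assert extract_fallback_keys(upload_body) == {"alg1:k1": "key1"}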
@@ -201,6 +201,8 @@ class SyncResult:
        device_lists: List of user_ids whose devices have changed
        device_one_time_keys_count: Dict of algorithm to count for one time keys
            for this device
        device_unused_fallback_key_types: List of key types that have an unused fallback
            key
        groups: Group updates, if any
    """
@@ -213,6 +215,7 @@
    to_device = attr.ib(type=List[JsonDict])
    device_lists = attr.ib(type=DeviceLists)
    device_one_time_keys_count = attr.ib(type=JsonDict)
    device_unused_fallback_key_types = attr.ib(type=List[str])
    groups = attr.ib(type=Optional[GroupsSyncResult])

    def __bool__(self) -> bool:
@@ -1014,10 +1017,14 @@ class SyncHandler:
        logger.debug("Fetching OTK data")
        device_id = sync_config.device_id
        one_time_key_counts = {}  # type: JsonDict
        unused_fallback_key_types = []  # type: List[str]
        if device_id:
            one_time_key_counts = await self.store.count_e2e_one_time_keys(
                user_id, device_id
            )
            unused_fallback_key_types = await self.store.get_e2e_unused_fallback_key_types(
                user_id, device_id
            )

        logger.debug("Fetching group data")
        await self._generate_sync_entry_for_groups(sync_result_builder)
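For illustration, a typical pair of values produced here for a device with olm keys uploaded (the count is made up; signed_curve25519 is the usual olm key algorithm):

one_time_key_counts = {"signed_curve25519": 50}     # from count_e2e_one_time_keys
unused_fallback_key_types = ["signed_curve25519"]   # from get_e2e_unused_fallback_key_types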
@@ -1041,6 +1048,7 @@ class SyncHandler:
            device_lists=device_lists,
            groups=sync_result_builder.groups,
            device_one_time_keys_count=one_time_key_counts,
            device_unused_fallback_key_types=unused_fallback_key_types,
            next_batch=sync_result_builder.now_token,
        )
@@ -236,6 +236,7 @@ class SyncRestServlet(RestServlet):
                "leave": sync_result.groups.leave,
            },
            "device_one_time_keys_count": sync_result.device_one_time_keys_count,
            "org.matrix.msc2732.device_unused_fallback_key_types": sync_result.device_unused_fallback_key_types,
            "next_batch": await sync_result.next_batch.to_string(self.store),
        }
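Clients read the unstable field alongside the existing one-time key counts and use it to decide when a fresh fallback key should be uploaded. A hedged, self-contained sketch of that client-side check (the helper and the literal values are illustrative only):

from typing import List

sync_fragment = {
    "device_one_time_keys_count": {"signed_curve25519": 50},
    "org.matrix.msc2732.device_unused_fallback_key_types": ["signed_curve25519"],
}

def needs_new_fallback_key(response: dict, algorithm: str = "signed_curve25519") -> bool:
    unused: List[str] = response.get(
        "org.matrix.msc2732.device_unused_fallback_key_types", []
    )
    # Once the server stops reporting the algorithm as having an unused
    # fallback key, the client should upload a replacement.
    return algorithm not in unused

assert needs_new_fallback_key(sync_fragment) is False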
@@ -367,6 +367,57 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
            "count_e2e_one_time_keys", _count_e2e_one_time_keys
        )

    async def set_e2e_fallback_keys(
        self, user_id: str, device_id: str, fallback_keys: JsonDict
    ) -> None:
        """Set the user's e2e fallback keys.

        Args:
            user_id: the user whose keys are being set
            device_id: the device whose keys are being set
            fallback_keys: the keys to set. This is a map from key ID (which is
                of the form "algorithm:id") to key data.
        """
        # fallback_keys will usually only have one item in it, so using a for
        # loop (as opposed to calling simple_upsert_many_txn) won't be too bad
        # FIXME: make sure that only one key per algorithm is uploaded
        for key_id, fallback_key in fallback_keys.items():
            algorithm, key_id = key_id.split(":", 1)
            await self.db_pool.simple_upsert(
                "e2e_fallback_keys_json",
                keyvalues={
                    "user_id": user_id,
                    "device_id": device_id,
                    "algorithm": algorithm,
                },
                values={
                    "key_id": key_id,
                    "key_json": json_encoder.encode(fallback_key),
                    "used": False,
                },
                desc="set_e2e_fallback_key",
            )

    @cached(max_entries=10000)
    async def get_e2e_unused_fallback_key_types(
        self, user_id: str, device_id: str
    ) -> List[str]:
        """Returns the fallback key types that have an unused key.

        Args:
            user_id: the user whose keys are being queried
            device_id: the device whose keys are being queried

        Returns:
            a list of key types
        """
        return await self.db_pool.simple_select_onecol(
            "e2e_fallback_keys_json",
            keyvalues={"user_id": user_id, "device_id": device_id, "used": False},
            retcol="algorithm",
            desc="get_e2e_unused_fallback_key_types",
        )

    async def get_e2e_cross_signing_key(
        self, user_id: str, key_type: str, from_user_id: Optional[str] = None
    ) -> Optional[dict]:
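To make the storage semantics concrete: there is at most one fallback key row per (user_id, device_id, algorithm), so re-uploading replaces the previous key and resets the used flag, and the unused-types query simply lists algorithms whose row still has used set to False. A self-contained, in-memory stand-in, with plain dicts in place of the database tables (not Synapse code):

from typing import Dict, List, Tuple

fallback_table: Dict[Tuple[str, str, str], dict] = {}  # (user, device, algorithm) -> row

def set_fallback_keys(user_id: str, device_id: str, fallback_keys: Dict[str, dict]) -> None:
    for key_id, fallback_key in fallback_keys.items():
        algorithm, key_id = key_id.split(":", 1)
        # One row per algorithm: re-uploading replaces the key and clears "used".
        fallback_table[(user_id, device_id, algorithm)] = {
            "key_id": key_id,
            "key_json": fallback_key,
            "used": False,
        }

def unused_fallback_key_types(user_id: str, device_id: str) -> List[str]:
    return [
        algorithm
        for (u, d, algorithm), row in fallback_table.items()
        if u == user_id and d == device_id and not row["used"]
    ]

set_fallback_keys("@boris:hs", "xyz", {"alg1:k1": {"key": "key1"}})
assert unused_fallback_key_types("@boris:hs", "xyz") == ["alg1"]
set_fallback_keys("@boris:hs", "xyz", {"alg1:k3": {"key": "key3"}})
assert fallback_table[("@boris:hs", "xyz", "alg1")]["key_id"] == "k3"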
@@ -701,15 +752,37 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
                " WHERE user_id = ? AND device_id = ? AND algorithm = ?"
                " LIMIT 1"
            )
            fallback_sql = (
                "SELECT key_id, key_json, used FROM e2e_fallback_keys_json"
                " WHERE user_id = ? AND device_id = ? AND algorithm = ?"
                " LIMIT 1"
            )
            result = {}
            delete = []
            used_fallbacks = []
            for user_id, device_id, algorithm in query_list:
                user_result = result.setdefault(user_id, {})
                device_result = user_result.setdefault(device_id, {})
                txn.execute(sql, (user_id, device_id, algorithm))
                for key_id, key_json in txn:
                otk_row = txn.fetchone()
                if otk_row is not None:
                    key_id, key_json = otk_row
                    device_result[algorithm + ":" + key_id] = key_json
                    delete.append((user_id, device_id, algorithm, key_id))
                else:
                    # no one-time key available, so see if there's a fallback
                    # key
                    txn.execute(fallback_sql, (user_id, device_id, algorithm))
                    fallback_row = txn.fetchone()
                    if fallback_row is not None:
                        key_id, key_json, used = fallback_row
                        device_result[algorithm + ":" + key_id] = key_json
                        if not used:
                            used_fallbacks.append(
                                (user_id, device_id, algorithm, key_id)
                            )

            # drop any one-time keys that were claimed
            sql = (
                "DELETE FROM e2e_one_time_keys_json"
                " WHERE user_id = ? AND device_id = ? AND algorithm = ?"
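The claim path can be summarised as: hand out and delete a one-time key if one exists, otherwise hand out the fallback key and flag it as used (it is not deleted, so repeated claims keep returning it). A self-contained sketch of that decision, with plain dicts standing in for the two tables rather than the transaction code above:

from typing import Dict, List, Optional, Tuple

one_time_keys: Dict[Tuple[str, str, str], List[Tuple[str, str]]] = {}  # -> [(key_id, key_json)]
fallback_keys: Dict[Tuple[str, str, str], dict] = {}                   # -> {"key_id", "key_json", "used"}

def claim_key(user_id: str, device_id: str, algorithm: str) -> Optional[Tuple[str, str]]:
    bucket = one_time_keys.get((user_id, device_id, algorithm))
    if bucket:
        key_id, key_json = bucket.pop(0)  # one-time keys are consumed on claim
        return (algorithm + ":" + key_id, key_json)
    row = fallback_keys.get((user_id, device_id, algorithm))
    if row is not None:
        row["used"] = True                # fallback keys are only flagged, not removed
        return (algorithm + ":" + row["key_id"], row["key_json"])
    return None                           # nothing available to hand out

fallback_keys[("@boris:hs", "xyz", "alg1")] = {"key_id": "k1", "key_json": "key1", "used": False}
assert claim_key("@boris:hs", "xyz", "alg1") == ("alg1:k1", "key1")
assert claim_key("@boris:hs", "xyz", "alg1") == ("alg1:k1", "key1")  # same key again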
@@ -726,6 +799,23 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
                self._invalidate_cache_and_stream(
                    txn, self.count_e2e_one_time_keys, (user_id, device_id)
                )
            # mark fallback keys as used
            for user_id, device_id, algorithm, key_id in used_fallbacks:
                self.db_pool.simple_update_txn(
                    txn,
                    "e2e_fallback_keys_json",
                    {
                        "user_id": user_id,
                        "device_id": device_id,
                        "algorithm": algorithm,
                        "key_id": key_id,
                    },
                    {"used": True},
                )
                self._invalidate_cache_and_stream(
                    txn, self.get_e2e_unused_fallback_key_types, (user_id, device_id)
                )

            return result

        return await self.db_pool.runInteraction(
@@ -754,6 +844,14 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
            self._invalidate_cache_and_stream(
                txn, self.count_e2e_one_time_keys, (user_id, device_id)
            )
            self.db_pool.simple_delete_txn(
                txn,
                table="e2e_fallback_keys_json",
                keyvalues={"user_id": user_id, "device_id": device_id},
            )
            self._invalidate_cache_and_stream(
                txn, self.get_e2e_unused_fallback_key_types, (user_id, device_id)
            )

        await self.db_pool.runInteraction(
            "delete_e2e_keys_by_device", delete_e2e_keys_by_device_txn
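Deleting a device now clears its fallback keys along with its one-time keys, and invalidates the cached unused-types list so a later sync stops advertising a fallback key for the removed device. A brief in-memory sketch of the same idea (illustrative only, reusing the dict stand-in from above):

from typing import Dict, Tuple

fallback_keys: Dict[Tuple[str, str, str], dict] = {
    ("@boris:hs", "xyz", "alg1"): {"key_id": "k1", "key_json": "key1", "used": False},
}

def delete_device_keys(user_id: str, device_id: str) -> None:
    # Drop every fallback row belonging to the device, whatever its algorithm.
    for key in [k for k in fallback_keys if k[0] == user_id and k[1] == device_id]:
        del fallback_keys[key]

delete_device_keys("@boris:hs", "xyz")
assert fallback_keys == {}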
@@ -0,0 +1,24 @@
/* Copyright 2020 The Matrix.org Foundation C.I.C
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

CREATE TABLE IF NOT EXISTS e2e_fallback_keys_json (
    user_id TEXT NOT NULL, -- The user this fallback key is for.
    device_id TEXT NOT NULL, -- The device this fallback key is for.
    algorithm TEXT NOT NULL, -- Which algorithm this fallback key is for.
    key_id TEXT NOT NULL, -- An id for suppressing duplicate uploads.
    key_json TEXT NOT NULL, -- The key as a JSON blob.
    used BOOLEAN NOT NULL DEFAULT FALSE, -- Whether the key has been used or not.
    CONSTRAINT e2e_fallback_keys_json_uniqueness UNIQUE (user_id, device_id, algorithm)
);
@@ -171,6 +171,71 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
            },
        )

    @defer.inlineCallbacks
    def test_fallback_key(self):
        local_user = "@boris:" + self.hs.hostname
        device_id = "xyz"
        fallback_key = {"alg1:k1": "key1"}
        otk = {"alg1:k2": "key2"}

        yield defer.ensureDeferred(
            self.handler.upload_keys_for_user(
                local_user,
                device_id,
                {"org.matrix.msc2732.fallback_keys": fallback_key},
            )
        )

        # claiming an OTK when no OTKs are available should return the fallback
        # key
        res = yield defer.ensureDeferred(
            self.handler.claim_one_time_keys(
                {"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None
            )
        )
        self.assertEqual(
            res,
            {"failures": {}, "one_time_keys": {local_user: {device_id: fallback_key}}},
        )

        # claiming an OTK again should return the same fallback key
        res = yield defer.ensureDeferred(
            self.handler.claim_one_time_keys(
                {"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None
            )
        )
        self.assertEqual(
            res,
            {"failures": {}, "one_time_keys": {local_user: {device_id: fallback_key}}},
        )

        # if the user uploads a one-time key, the next claim should fetch the
        # one-time key, and then go back to the fallback
        yield defer.ensureDeferred(
            self.handler.upload_keys_for_user(
                local_user, device_id, {"one_time_keys": otk}
            )
        )

        res = yield defer.ensureDeferred(
            self.handler.claim_one_time_keys(
                {"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None
            )
        )
        self.assertEqual(
            res, {"failures": {}, "one_time_keys": {local_user: {device_id: otk}}},
        )

        res = yield defer.ensureDeferred(
            self.handler.claim_one_time_keys(
                {"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None
            )
        )
        self.assertEqual(
            res,
            {"failures": {}, "one_time_keys": {local_user: {device_id: fallback_key}}},
        )

    @defer.inlineCallbacks
    def test_replace_master_key(self):
        """uploading a new signing key should make the old signing key unavailable"""