Make _get_e2e_device_keys_and_signatures_txn return an attrs (#8224)

this makes it a bit clearer what's going on.
pull/8232/head
Richard van der Hoff 2020-09-02 11:47:26 +01:00 committed by GitHub
parent b939251c37
commit abeab964d5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 41 additions and 20 deletions

1
changelog.d/8224.misc Normal file
View File

@ -0,0 +1 @@
Refactor queries for device keys and cross-signatures.

View File

@ -293,17 +293,17 @@ class DeviceWorkerStore(SQLBaseStore):
prev_id = stream_id
if device is not None:
key_json = device.get("key_json", None)
key_json = device.key_json
if key_json:
result["keys"] = db_to_json(key_json)
if "signatures" in device:
for sig_user_id, sigs in device["signatures"].items():
if device.signatures:
for sig_user_id, sigs in device.signatures.items():
result["keys"].setdefault("signatures", {}).setdefault(
sig_user_id, {}
).update(sigs)
device_display_name = device.get("device_display_name", None)
device_display_name = device.display_name
if device_display_name:
result["device_display_name"] = device_display_name
else:

View File

@ -17,6 +17,7 @@
import abc
from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple
import attr
from canonicaljson import encode_canonical_json
from twisted.enterprise.adbapi import Connection
@ -33,6 +34,21 @@ if TYPE_CHECKING:
from synapse.handlers.e2e_keys import SignatureListItem
@attr.s
class DeviceKeyLookupResult:
"""The type returned by _get_e2e_device_keys_and_signatures_txn"""
display_name = attr.ib(type=Optional[str])
# the key data from e2e_device_keys_json. Typically includes fields like
# "algorithm", "keys" (including the curve25519 identity key and the ed25519 signing
# key) and "signatures" (a signature of the structure by the ed25519 key)
key_json = attr.ib(type=Optional[str])
# cross-signing sigs
signatures = attr.ib(type=Optional[Dict], default=None)
class EndToEndKeyWorkerStore(SQLBaseStore):
async def get_e2e_device_keys_for_federation_query(
self, user_id: str
@ -61,17 +77,17 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
for device_id, device in user_devices.items():
result = {"device_id": device_id}
key_json = device.get("key_json", None)
key_json = device.key_json
if key_json:
result["keys"] = db_to_json(key_json)
if "signatures" in device:
for sig_user_id, sigs in device["signatures"].items():
if device.signatures:
for sig_user_id, sigs in device.signatures.items():
result["keys"].setdefault("signatures", {}).setdefault(
sig_user_id, {}
).update(sigs)
device_display_name = device.get("device_display_name", None)
device_display_name = device.display_name
if device_display_name:
result["device_display_name"] = device_display_name
@ -109,13 +125,13 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
for user_id, device_keys in results.items():
rv[user_id] = {}
for device_id, device_info in device_keys.items():
r = db_to_json(device_info.pop("key_json"))
r = db_to_json(device_info.key_json)
r["unsigned"] = {}
display_name = device_info["device_display_name"]
display_name = device_info.display_name
if display_name is not None:
r["unsigned"]["device_display_name"] = display_name
if "signatures" in device_info:
for sig_user_id, sigs in device_info["signatures"].items():
if device_info.signatures:
for sig_user_id, sigs in device_info.signatures.items():
r.setdefault("signatures", {}).setdefault(
sig_user_id, {}
).update(sigs)
@ -126,7 +142,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
@trace
def _get_e2e_device_keys_and_signatures_txn(
self, txn, query_list, include_all_devices=False, include_deleted_devices=False
) -> Dict[str, Dict[str, Optional[Dict]]]:
) -> Dict[str, Dict[str, Optional[DeviceKeyLookupResult]]]:
set_tag("include_all_devices", include_all_devices)
set_tag("include_deleted_devices", include_deleted_devices)
@ -161,7 +177,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
sql = (
"SELECT user_id, device_id, "
" d.display_name AS device_display_name, "
" d.display_name, "
" k.key_json"
" FROM devices d"
" %s JOIN e2e_device_keys_json k USING (user_id, device_id)"
@ -172,13 +188,14 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
)
txn.execute(sql, query_params)
rows = self.db_pool.cursor_to_dict(txn)
result = {}
for row in rows:
result = {} # type: Dict[str, Dict[str, Optional[DeviceKeyLookupResult]]]
for (user_id, device_id, display_name, key_json) in txn:
if include_deleted_devices:
deleted_devices.remove((row["user_id"], row["device_id"]))
result.setdefault(row["user_id"], {})[row["device_id"]] = row
deleted_devices.remove((user_id, device_id))
result.setdefault(user_id, {})[device_id] = DeviceKeyLookupResult(
display_name, key_json
)
if include_deleted_devices:
for user_id, device_id in deleted_devices:
@ -209,7 +226,10 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
# note that target_device_result will be None for deleted devices.
continue
target_device_signatures = target_device_result.setdefault("signatures", {})
target_device_signatures = target_device_result.signatures
if target_device_signatures is None:
target_device_signatures = target_device_result.signatures = {}
signing_user_signatures = target_device_signatures.setdefault(
signing_user_id, {}
)