diff --git a/changelog.d/8205.misc b/changelog.d/8205.misc new file mode 100644 index 0000000000..fb8fd83278 --- /dev/null +++ b/changelog.d/8205.misc @@ -0,0 +1 @@ + Refactor queries for device keys and cross-signatures. diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py index d8def45e38..dfd1c78549 100644 --- a/synapse/handlers/e2e_keys.py +++ b/synapse/handlers/e2e_keys.py @@ -353,7 +353,7 @@ class E2eKeysHandler(object): # make sure that each queried user appears in the result dict result_dict[user_id] = {} - results = await self.store.get_e2e_device_keys(local_query) + results = await self.store.get_e2e_device_keys_for_cs_api(local_query) # Build the result structure for user_id, device_keys in results.items(): @@ -734,7 +734,7 @@ class E2eKeysHandler(object): # fetch our stored devices. This is used to 1. verify # signatures on the master key, and 2. to compare with what # was sent if the device was signed - devices = await self.store.get_e2e_device_keys([(user_id, None)]) + devices = await self.store.get_e2e_device_keys_for_cs_api([(user_id, None)]) if user_id not in devices: raise NotFoundError("No device keys found") diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py index af0b85e2c9..50ecddf7fa 100644 --- a/synapse/storage/databases/main/end_to_end_keys.py +++ b/synapse/storage/databases/main/end_to_end_keys.py @@ -23,6 +23,7 @@ from twisted.enterprise.adbapi import Connection from synapse.logging.opentracing import log_kv, set_tag, trace from synapse.storage._base import SQLBaseStore, db_to_json from synapse.storage.database import make_in_list_sql_clause +from synapse.types import JsonDict from synapse.util import json_encoder from synapse.util.caches.descriptors import cached, cachedList from synapse.util.iterutils import batch_iter @@ -33,17 +34,12 @@ if TYPE_CHECKING: class EndToEndKeyWorkerStore(SQLBaseStore): @trace - async def get_e2e_device_keys( - self, query_list, include_all_devices=False, include_deleted_devices=False - ): - """Fetch a list of device keys. + async def get_e2e_device_keys_for_cs_api( + self, query_list: List[Tuple[str, Optional[str]]] + ) -> Dict[str, Dict[str, JsonDict]]: + """Fetch a list of device keys, formatted suitably for the C/S API. Args: query_list(list): List of pairs of user_ids and device_ids. - include_all_devices (bool): whether to include entries for devices - that don't have device keys - include_deleted_devices (bool): whether to include null entries for - devices which no longer exist (but were in the query_list). - This option only takes effect if include_all_devices is true. Returns: Dict mapping from user-id to dict mapping from device_id to key data. The key data will be a dict in the same format as the @@ -54,11 +50,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore): return {} results = await self.db_pool.runInteraction( - "get_e2e_device_keys", - self._get_e2e_device_keys_txn, - query_list, - include_all_devices, - include_deleted_devices, + "get_e2e_device_keys", self._get_e2e_device_keys_txn, query_list, ) # Build the result structure, un-jsonify the results, and add the diff --git a/tests/storage/test_end_to_end_keys.py b/tests/storage/test_end_to_end_keys.py index 261bf5b08b..3fc4bb13b6 100644 --- a/tests/storage/test_end_to_end_keys.py +++ b/tests/storage/test_end_to_end_keys.py @@ -37,7 +37,7 @@ class EndToEndKeyStoreTestCase(tests.unittest.TestCase): ) res = yield defer.ensureDeferred( - self.store.get_e2e_device_keys((("user", "device"),)) + self.store.get_e2e_device_keys_for_cs_api((("user", "device"),)) ) self.assertIn("user", res) self.assertIn("device", res["user"]) @@ -76,7 +76,7 @@ class EndToEndKeyStoreTestCase(tests.unittest.TestCase): ) res = yield defer.ensureDeferred( - self.store.get_e2e_device_keys((("user", "device"),)) + self.store.get_e2e_device_keys_for_cs_api((("user", "device"),)) ) self.assertIn("user", res) self.assertIn("device", res["user"]) @@ -108,7 +108,9 @@ class EndToEndKeyStoreTestCase(tests.unittest.TestCase): ) res = yield defer.ensureDeferred( - self.store.get_e2e_device_keys((("user1", "device1"), ("user2", "device2"))) + self.store.get_e2e_device_keys_for_cs_api( + (("user1", "device1"), ("user2", "device2")) + ) ) self.assertIn("user1", res) self.assertIn("device1", res["user1"])