diff --git a/synapse/storage/devices.py b/synapse/storage/devices.py index 0c797f9f3e..203f50f07d 100644 --- a/synapse/storage/devices.py +++ b/synapse/storage/devices.py @@ -239,7 +239,6 @@ class DeviceStore(SQLBaseStore): def update_remote_device_list_cache_entry(self, user_id, device_id, content, stream_id): """Updates a single user's device in the cache. - If the content is null, delete the device from the cache. """ return self.runInteraction( "update_remote_device_list_cache_entry", @@ -249,7 +248,7 @@ class DeviceStore(SQLBaseStore): def _update_remote_device_list_cache_entry_txn(self, txn, user_id, device_id, content, stream_id): - if content is None: + if content.get("deleted"): self._simple_delete_txn( txn, table="device_lists_remote_cache", @@ -409,12 +408,15 @@ class DeviceStore(SQLBaseStore): prev_id = stream_id - key_json = device.get("key_json", None) - if key_json: - result["keys"] = json.loads(key_json) - device_display_name = device.get("device_display_name", None) - if device_display_name: - result["device_display_name"] = device_display_name + if device is not None: + key_json = device.get("key_json", None) + if key_json: + result["keys"] = json.loads(key_json) + device_display_name = device.get("device_display_name", None) + if device_display_name: + result["device_display_name"] = device_display_name + else: + result["deleted"] = True results.append(result) diff --git a/synapse/storage/end_to_end_keys.py b/synapse/storage/end_to_end_keys.py index f61553cec8..6c28719420 100644 --- a/synapse/storage/end_to_end_keys.py +++ b/synapse/storage/end_to_end_keys.py @@ -74,7 +74,7 @@ class EndToEndKeyStore(SQLBaseStore): include_all_devices (bool): whether to include entries for devices that don't have device keys include_deleted_devices (bool): whether to include null entries for - devices which no longer exist (but where in the query_list) + devices which no longer exist (but were in the query_list) Returns: Dict mapping from user-id to dict mapping from device_id to dict containing "key_json", "device_display_name". @@ -84,28 +84,25 @@ class EndToEndKeyStore(SQLBaseStore): results = yield self.runInteraction( "get_e2e_device_keys", self._get_e2e_device_keys_txn, - query_list, include_all_devices, + query_list, include_all_devices, include_deleted_devices, ) - if include_deleted_devices: - deleted_devices = set(query_list) - for user_id, device_keys in iteritems(results): for device_id, device_info in iteritems(device_keys): - if include_deleted_devices: - deleted_devices -= (user_id, device_id) device_info["keys"] = json.loads(device_info.pop("key_json")) - if include_deleted_devices: - for user_id, device_id in deleted_devices: - results.setdefault(user_id, {})[device_id] = None - defer.returnValue(results) - def _get_e2e_device_keys_txn(self, txn, query_list, include_all_devices): + def _get_e2e_device_keys_txn( + self, txn, query_list, include_all_devices=False, + include_deleted_devices=False, + ): query_clauses = [] query_params = [] + if include_deleted_devices: + deleted_devices = set(query_list) + for (user_id, device_id) in query_list: query_clause = "user_id = ?" query_params.append(user_id) @@ -133,8 +130,14 @@ class EndToEndKeyStore(SQLBaseStore): result = {} for row in rows: + if include_deleted_devices: + deleted_devices.remove((row["user_id"], row["device_id"])) result.setdefault(row["user_id"], {})[row["device_id"]] = row + if include_deleted_devices: + for user_id, device_id in deleted_devices: + result.setdefault(user_id, {})[device_id] = None + return result @defer.inlineCallbacks