# -*- coding: utf-8 -*-
# Copyright 2015, 2016 OpenMarket Ltd
# Copyright 2019 New Vector Ltd
# Copyright 2019 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Dict, List, Tuple

from canonicaljson import encode_canonical_json, json

from twisted.enterprise.adbapi import Connection
from twisted.internet import defer

from synapse.logging.opentracing import log_kv, set_tag, trace
from synapse.storage._base import SQLBaseStore, db_to_json
from synapse.storage.database import make_in_list_sql_clause
from synapse.util.caches.descriptors import cached, cachedList
from synapse.util.iterutils import batch_iter


class EndToEndKeyWorkerStore(SQLBaseStore):
    @trace
    @defer.inlineCallbacks
    def get_e2e_device_keys(
        self, query_list, include_all_devices=False, include_deleted_devices=False
    ):
        """Fetch a list of device keys.

        Args:
            query_list(list): List of pairs of user_ids and device_ids.
            include_all_devices (bool): whether to include entries for devices
                that don't have device keys
            include_deleted_devices (bool): whether to include null entries for
                devices which no longer exist (but were in the query_list).
                This option only takes effect if include_all_devices is true.

        Returns:
            Dict mapping from user-id to dict mapping from device_id to
            key data. The key data will be a dict in the same format as the
            DeviceKeys type returned by POST /_matrix/client/r0/keys/query.
        """
        set_tag("query_list", query_list)
        if not query_list:
            return {}

        results = yield self.db_pool.runInteraction(
            "get_e2e_device_keys",
            self._get_e2e_device_keys_txn,
            query_list,
            include_all_devices,
            include_deleted_devices,
        )

        # Build the result structure, un-jsonify the results, and add the
        # "unsigned" section
        rv = {}
        for user_id, device_keys in results.items():
            rv[user_id] = {}
            for device_id, device_info in device_keys.items():
                r = db_to_json(device_info.pop("key_json"))
                r["unsigned"] = {}
                display_name = device_info["device_display_name"]
                if display_name is not None:
                    r["unsigned"]["device_display_name"] = display_name
                if "signatures" in device_info:
                    for sig_user_id, sigs in device_info["signatures"].items():
                        r.setdefault("signatures", {}).setdefault(
                            sig_user_id, {}
                        ).update(sigs)
                rv[user_id][device_id] = r

        return rv

    @trace
    def _get_e2e_device_keys_txn(
        self, txn, query_list, include_all_devices=False, include_deleted_devices=False
    ):
        set_tag("include_all_devices", include_all_devices)
        set_tag("include_deleted_devices", include_deleted_devices)

        query_clauses = []
        query_params = []
        signature_query_clauses = []
        signature_query_params = []

        if include_all_devices is False:
            include_deleted_devices = False

        if include_deleted_devices:
            deleted_devices = set(query_list)

        for (user_id, device_id) in query_list:
            query_clause = "user_id = ?"
            query_params.append(user_id)
            signature_query_clause = "target_user_id = ?"
            signature_query_params.append(user_id)

            if device_id is not None:
                query_clause += " AND device_id = ?"
                query_params.append(device_id)
                signature_query_clause += " AND target_device_id = ?"
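                # Signatures live in e2e_cross_signing_signatures keyed on
                # (target_user_id, target_device_id), so a per-device lookup
                # also binds the device id as a parameter below.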
                signature_query_params.append(device_id)

            signature_query_clause += " AND user_id = ?"
            signature_query_params.append(user_id)

            query_clauses.append(query_clause)
            signature_query_clauses.append(signature_query_clause)

        sql = (
            "SELECT user_id, device_id, "
            "    d.display_name AS device_display_name, "
            "    k.key_json"
            " FROM devices d"
            "    %s JOIN e2e_device_keys_json k USING (user_id, device_id)"
            " WHERE %s AND NOT d.hidden"
        ) % (
            "LEFT" if include_all_devices else "INNER",
            " OR ".join("(" + q + ")" for q in query_clauses),
        )

        txn.execute(sql, query_params)
        rows = self.db_pool.cursor_to_dict(txn)

        result = {}
        for row in rows:
            if include_deleted_devices:
                deleted_devices.remove((row["user_id"], row["device_id"]))
            result.setdefault(row["user_id"], {})[row["device_id"]] = row

        if include_deleted_devices:
            for user_id, device_id in deleted_devices:
                result.setdefault(user_id, {})[device_id] = None

        # get signatures on the device
        signature_sql = ("SELECT * FROM e2e_cross_signing_signatures WHERE %s") % (
            " OR ".join("(" + q + ")" for q in signature_query_clauses)
        )

        txn.execute(signature_sql, signature_query_params)
        rows = self.db_pool.cursor_to_dict(txn)

        # add each cross-signing signature to the correct device in the result dict.
        for row in rows:
            signing_user_id = row["user_id"]
            signing_key_id = row["key_id"]
            target_user_id = row["target_user_id"]
            target_device_id = row["target_device_id"]
            signature = row["signature"]

            target_user_result = result.get(target_user_id)
            if not target_user_result:
                continue

            target_device_result = target_user_result.get(target_device_id)
            if not target_device_result:
                # note that target_device_result will be None for deleted devices.
                continue

            target_device_signatures = target_device_result.setdefault(
                "signatures", {}
            )
            signing_user_signatures = target_device_signatures.setdefault(
                signing_user_id, {}
            )
            signing_user_signatures[signing_key_id] = signature

        log_kv(result)
        return result

    @defer.inlineCallbacks
    def get_e2e_one_time_keys(self, user_id, device_id, key_ids):
        """Retrieve a number of one-time keys for a user

        Args:
            user_id(str): id of user to get keys for
            device_id(str): id of device to get keys for
            key_ids(list[str]): list of key ids (excluding algorithm) to
                retrieve

        Returns:
            deferred resolving to Dict[(str, str), str]: map from (algorithm,
            key_id) to json string for key
        """
        rows = yield self.db_pool.simple_select_many_batch(
            table="e2e_one_time_keys_json",
            column="key_id",
            iterable=key_ids,
            retcols=("algorithm", "key_id", "key_json"),
            keyvalues={"user_id": user_id, "device_id": device_id},
            desc="add_e2e_one_time_keys_check",
        )
        result = {(row["algorithm"], row["key_id"]): row["key_json"] for row in rows}
        log_kv({"message": "Fetched one time keys for user", "one_time_keys": result})
        return result

    @defer.inlineCallbacks
    def add_e2e_one_time_keys(self, user_id, device_id, time_now, new_keys):
        """Insert some new one-time keys for a device. Errors if any of the
        keys already exist.

        Args:
            user_id(str): id of user to add keys for
            device_id(str): id of device to add keys for
            time_now(long): insertion time to record (ms since epoch)
            new_keys(iterable[(str, str, str)]): keys to add - each a tuple of
                (algorithm, key_id, key json)
        """

        def _add_e2e_one_time_keys(txn):
            set_tag("user_id", user_id)
            set_tag("device_id", device_id)
            set_tag("new_keys", new_keys)
            # We are protected from race between lookup and insertion due to
            # a unique constraint. If there is a race of two calls to
            # `add_e2e_one_time_keys` then they'll conflict and we will only
            # insert one set.
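            # Illustrative example (values are made up): each entry of `new_keys`
            # is an (algorithm, key_id, key_json) tuple such as
            # ("signed_curve25519", "AAAAHQ", '{"key": "...", "signatures": {...}}').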
            self.db_pool.simple_insert_many_txn(
                txn,
                table="e2e_one_time_keys_json",
                values=[
                    {
                        "user_id": user_id,
                        "device_id": device_id,
                        "algorithm": algorithm,
                        "key_id": key_id,
                        "ts_added_ms": time_now,
                        "key_json": json_bytes,
                    }
                    for algorithm, key_id, json_bytes in new_keys
                ],
            )
            self._invalidate_cache_and_stream(
                txn, self.count_e2e_one_time_keys, (user_id, device_id)
            )

        yield self.db_pool.runInteraction(
            "add_e2e_one_time_keys_insert", _add_e2e_one_time_keys
        )

    @cached(max_entries=10000)
    def count_e2e_one_time_keys(self, user_id, device_id):
        """Count the number of one time keys the server has for a device

        Returns:
            Dict mapping from algorithm to number of keys for that algorithm.
        """

        def _count_e2e_one_time_keys(txn):
            sql = (
                "SELECT algorithm, COUNT(key_id) FROM e2e_one_time_keys_json"
                " WHERE user_id = ? AND device_id = ?"
                " GROUP BY algorithm"
            )
            txn.execute(sql, (user_id, device_id))
            result = {}
            for algorithm, key_count in txn:
                result[algorithm] = key_count
            return result

        return self.db_pool.runInteraction(
            "count_e2e_one_time_keys", _count_e2e_one_time_keys
        )

    @defer.inlineCallbacks
    def get_e2e_cross_signing_key(self, user_id, key_type, from_user_id=None):
        """Returns a user's cross-signing key.

        Args:
            user_id (str): the user whose key is being requested
            key_type (str): the type of key that is being requested: either 'master'
                for a master key, 'self_signing' for a self-signing key, or
                'user_signing' for a user-signing key
            from_user_id (str): if specified, signatures made by this user on
                the self-signing key will be included in the result

        Returns:
            dict of the key data or None if not found
        """
        res = yield self.get_e2e_cross_signing_keys_bulk([user_id], from_user_id)
        user_keys = res.get(user_id)
        if not user_keys:
            return None
        return user_keys.get(key_type)

    @cached(num_args=1)
    def _get_bare_e2e_cross_signing_keys(self, user_id):
        """Dummy function. Only used to make a cache for
        _get_bare_e2e_cross_signing_keys_bulk.
        """
        raise NotImplementedError()

    @cachedList(
        cached_method_name="_get_bare_e2e_cross_signing_keys",
        list_name="user_ids",
        num_args=1,
    )
    def _get_bare_e2e_cross_signing_keys_bulk(
        self, user_ids: List[str]
    ) -> Dict[str, Dict[str, dict]]:
        """Returns the cross-signing keys for a set of users. The output of this
        function should be passed to _get_e2e_cross_signing_signatures_txn if
        the signatures for the calling user need to be fetched.

        Args:
            user_ids (list[str]): the users whose keys are being requested

        Returns:
            dict[str, dict[str, dict]]: mapping from user ID to key type to key
                data. If a user's cross-signing keys were not found, either
                their user ID will not be in the dict, or their user ID will
                map to None.
        """
        return self.db_pool.runInteraction(
            "get_bare_e2e_cross_signing_keys_bulk",
            self._get_bare_e2e_cross_signing_keys_bulk_txn,
            user_ids,
        )

    def _get_bare_e2e_cross_signing_keys_bulk_txn(
        self, txn: Connection, user_ids: List[str],
    ) -> Dict[str, Dict[str, dict]]:
        """Returns the cross-signing keys for a set of users. The output of this
        function should be passed to _get_e2e_cross_signing_signatures_txn if
        the signatures for the calling user need to be fetched.

        Args:
            txn (twisted.enterprise.adbapi.Connection): db connection
            user_ids (list[str]): the users whose keys are being requested

        Returns:
            dict[str, dict[str, dict]]: mapping from user ID to key type to key
                data. If a user's cross-signing keys were not found, their user
                ID will not be in the dict.
        """
        result = {}

        for user_chunk in batch_iter(user_ids, 100):
            clause, params = make_in_list_sql_clause(
                txn.database_engine, "k.user_id", user_chunk
            )
            sql = (
                """
                SELECT k.user_id, k.keytype, k.keydata, k.stream_id
                    FROM e2e_cross_signing_keys k
                    INNER JOIN (SELECT user_id, keytype, MAX(stream_id) AS stream_id
                        FROM e2e_cross_signing_keys
                        GROUP BY user_id, keytype) s
                    USING (user_id, stream_id, keytype)
                    WHERE
                """
                + clause
            )
            txn.execute(sql, params)
            rows = self.db_pool.cursor_to_dict(txn)

            for row in rows:
                user_id = row["user_id"]
                key_type = row["keytype"]
                key = db_to_json(row["keydata"])
                user_info = result.setdefault(user_id, {})
                user_info[key_type] = key

        return result

    def _get_e2e_cross_signing_signatures_txn(
        self, txn: Connection, keys: Dict[str, Dict[str, dict]], from_user_id: str,
    ) -> Dict[str, Dict[str, dict]]:
        """Returns the cross-signing signatures made by a user on a set of keys.

        Args:
            txn (twisted.enterprise.adbapi.Connection): db connection
            keys (dict[str, dict[str, dict]]): a map of user ID to key type to
                key data. This dict will be modified to add signatures.
            from_user_id (str): fetch the signatures made by this user

        Returns:
            dict[str, dict[str, dict]]: mapping from user ID to key type to key
                data. The return value will be the same as the keys argument,
                with the modifications included.
        """

        # find out what cross-signing keys (a.k.a. devices) we need to get
        # signatures for. This is a map of (user_id, device_id) to key type
        # (device_id is the key's public part).
        devices = {}

        for user_id, user_info in keys.items():
            if user_info is None:
                continue
            for key_type, key in user_info.items():
                device_id = None
                for k in key["keys"].values():
                    device_id = k
                devices[(user_id, device_id)] = key_type

        for batch in batch_iter(devices.keys(), size=100):
            sql = """
                SELECT target_user_id, target_device_id, key_id, signature
                  FROM e2e_cross_signing_signatures
                 WHERE user_id = ?
                   AND (%s)
            """ % (
                " OR ".join(
                    "(target_user_id = ? AND target_device_id = ?)" for _ in batch
                )
            )
            query_params = [from_user_id]
            for item in batch:
                # item is a (user_id, device_id) tuple
                query_params.extend(item)

            txn.execute(sql, query_params)
            rows = self.db_pool.cursor_to_dict(txn)

            # and add the signatures to the appropriate keys
            for row in rows:
                key_id = row["key_id"]
                target_user_id = row["target_user_id"]
                target_device_id = row["target_device_id"]
                key_type = devices[(target_user_id, target_device_id)]
                # We need to copy everything, because the result may have come
                # from the cache. dict.copy only does a shallow copy, so we
                # need to recursively copy the dicts that will be modified.
                user_info = keys[target_user_id] = keys[target_user_id].copy()
                target_user_key = user_info[key_type] = user_info[key_type].copy()
                if "signatures" in target_user_key:
                    signatures = target_user_key["signatures"] = target_user_key[
                        "signatures"
                    ].copy()
                    if from_user_id in signatures:
                        user_sigs = signatures[from_user_id] = signatures[
                            from_user_id
                        ].copy()
                        user_sigs[key_id] = row["signature"]
                    else:
                        signatures[from_user_id] = {key_id: row["signature"]}
                else:
                    target_user_key["signatures"] = {
                        from_user_id: {key_id: row["signature"]}
                    }

        return keys

    @defer.inlineCallbacks
    def get_e2e_cross_signing_keys_bulk(
        self, user_ids: List[str], from_user_id: str = None
    ) -> defer.Deferred:
        """Returns the cross-signing keys for a set of users.

        Args:
            user_ids (list[str]): the users whose keys are being requested
            from_user_id (str): if specified, signatures made by this user on
                the self-signing keys will be included in the result

        Returns:
            Deferred[dict[str, dict[str, dict]]]: map of user ID to key type to
                key data. If a user's cross-signing keys were not found, either
                their user ID will not be in the dict, or their user ID will
                map to None.
        """
        result = yield self._get_bare_e2e_cross_signing_keys_bulk(user_ids)

        if from_user_id:
            result = yield self.db_pool.runInteraction(
                "get_e2e_cross_signing_signatures",
                self._get_e2e_cross_signing_signatures_txn,
                result,
                from_user_id,
            )

        return result

    async def get_all_user_signature_changes_for_remotes(
        self, instance_name: str, last_id: int, current_id: int, limit: int
    ) -> Tuple[List[Tuple[int, tuple]], int, bool]:
        """Get updates for the user signature replication stream.

        Note that the user signature stream represents when a user signs their
        device with their user-signing key, which is not published to other
        users or servers, so no `destination` is needed in the returned
        list. However, this is needed to poke workers.

        Args:
            instance_name: The writer we want to fetch updates from. Unused
                here since there is only ever one writer.
            last_id: The token to fetch updates from. Exclusive.
            current_id: The token to fetch updates up to. Inclusive.
            limit: The requested limit for the number of rows to return. The
                function may return more or fewer rows.

        Returns:
            A tuple consisting of: the updates, a token to use to fetch
            subsequent updates, and whether we returned fewer rows than exist
            between the requested tokens due to the limit.

            The token returned can be used in a subsequent call to this
            function to get further updates.

            The updates are a list of 2-tuples of stream ID and the row data.
        """

        if last_id == current_id:
            return [], current_id, False

        def _get_all_user_signature_changes_for_remotes_txn(txn):
            sql = """
                SELECT stream_id, from_user_id AS user_id
                FROM user_signature_stream
                WHERE ? < stream_id AND stream_id <= ?
                ORDER BY stream_id ASC
                LIMIT ?
            """
            txn.execute(sql, (last_id, current_id, limit))

            updates = [(row[0], (row[1:])) for row in txn]

            limited = False
            upto_token = current_id
            if len(updates) >= limit:
                upto_token = updates[-1][0]
                limited = True

            return updates, upto_token, limited

        return await self.db_pool.runInteraction(
            "get_all_user_signature_changes_for_remotes",
            _get_all_user_signature_changes_for_remotes_txn,
        )


class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
    def set_e2e_device_keys(self, user_id, device_id, time_now, device_keys):
        """Stores device keys for a device. Returns whether there was a change
        or the keys were already in the database.
        """

        def _set_e2e_device_keys_txn(txn):
            set_tag("user_id", user_id)
            set_tag("device_id", device_id)
            set_tag("time_now", time_now)
            set_tag("device_keys", device_keys)

            old_key_json = self.db_pool.simple_select_one_onecol_txn(
                txn,
                table="e2e_device_keys_json",
                keyvalues={"user_id": user_id, "device_id": device_id},
                retcol="key_json",
                allow_none=True,
            )

            # In py3 we need old_key_json to match new_key_json type. The DB
            # returns unicode while encode_canonical_json returns bytes.
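            # encode_canonical_json produces a deterministic serialisation
            # (sorted keys, minimal whitespace), so comparing the two JSON
            # strings below is enough to detect a real change in the keys.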
            new_key_json = encode_canonical_json(device_keys).decode("utf-8")

            if old_key_json == new_key_json:
                log_kv({"Message": "Device key already stored."})
                return False

            self.db_pool.simple_upsert_txn(
                txn,
                table="e2e_device_keys_json",
                keyvalues={"user_id": user_id, "device_id": device_id},
                values={"ts_added_ms": time_now, "key_json": new_key_json},
            )
            log_kv({"message": "Device keys stored."})
            return True

        return self.db_pool.runInteraction(
            "set_e2e_device_keys", _set_e2e_device_keys_txn
        )

    def claim_e2e_one_time_keys(self, query_list):
        """Take a list of one time keys out of the database"""

        @trace
        def _claim_e2e_one_time_keys(txn):
            sql = (
                "SELECT key_id, key_json FROM e2e_one_time_keys_json"
                " WHERE user_id = ? AND device_id = ? AND algorithm = ?"
                " LIMIT 1"
            )
            result = {}
            delete = []
            for user_id, device_id, algorithm in query_list:
                user_result = result.setdefault(user_id, {})
                device_result = user_result.setdefault(device_id, {})
                txn.execute(sql, (user_id, device_id, algorithm))
                for key_id, key_json in txn:
                    device_result[algorithm + ":" + key_id] = key_json
                    delete.append((user_id, device_id, algorithm, key_id))
            sql = (
                "DELETE FROM e2e_one_time_keys_json"
                " WHERE user_id = ? AND device_id = ? AND algorithm = ?"
                " AND key_id = ?"
            )
            for user_id, device_id, algorithm, key_id in delete:
                log_kv(
                    {
                        "message": "Executing claim e2e_one_time_keys transaction on database."
                    }
                )
                txn.execute(sql, (user_id, device_id, algorithm, key_id))
                log_kv({"message": "finished executing and invalidating cache"})
                self._invalidate_cache_and_stream(
                    txn, self.count_e2e_one_time_keys, (user_id, device_id)
                )
            return result

        return self.db_pool.runInteraction(
            "claim_e2e_one_time_keys", _claim_e2e_one_time_keys
        )

    def delete_e2e_keys_by_device(self, user_id, device_id):
        def delete_e2e_keys_by_device_txn(txn):
            log_kv(
                {
                    "message": "Deleting keys for device",
                    "device_id": device_id,
                    "user_id": user_id,
                }
            )
            self.db_pool.simple_delete_txn(
                txn,
                table="e2e_device_keys_json",
                keyvalues={"user_id": user_id, "device_id": device_id},
            )
            self.db_pool.simple_delete_txn(
                txn,
                table="e2e_one_time_keys_json",
                keyvalues={"user_id": user_id, "device_id": device_id},
            )
            self._invalidate_cache_and_stream(
                txn, self.count_e2e_one_time_keys, (user_id, device_id)
            )

        return self.db_pool.runInteraction(
            "delete_e2e_keys_by_device", delete_e2e_keys_by_device_txn
        )

    def _set_e2e_cross_signing_key_txn(self, txn, user_id, key_type, key):
        """Set a user's cross-signing key.

        Args:
            txn (twisted.enterprise.adbapi.Connection): db connection
            user_id (str): the user to set the signing key for
            key_type (str): the type of key that is being set: either 'master'
                for a master key, 'self_signing' for a self-signing key, or
                'user_signing' for a user-signing key
            key (dict): the key data
        """
        # the 'key' dict will look something like:
        # {
        #   "user_id": "@alice:example.com",
        #   "usage": ["self_signing"],
        #   "keys": {
        #     "ed25519:base64+self+signing+public+key": "base64+self+signing+public+key",
        #   },
        #   "signatures": {
        #     "@alice:example.com": {
        #       "ed25519:base64+master+public+key": "base64+signature"
        #     }
        #   }
        # }
        # The "keys" property must only have one entry, which will be the public
        # key, so we just grab the first value in there
        pubkey = next(iter(key["keys"].values()))

        # The cross-signing keys need to occupy the same namespace as devices,
        # since signatures are identified by device ID. So add an entry to the
        # device table to make sure that we don't have a collision with device
        # IDs.
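        # For the example key above, that means a hidden "device" whose
        # device_id is the base64-encoded public key (illustrative only; the
        # actual insert happens just below, and only for local users).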
        # We only need to do this for local users, since remote servers should be
        # responsible for checking this for their own users.
        if self.hs.is_mine_id(user_id):
            self.db_pool.simple_insert_txn(
                txn,
                "devices",
                values={
                    "user_id": user_id,
                    "device_id": pubkey,
                    "display_name": key_type + " signing key",
                    "hidden": True,
                },
            )

        # and finally, store the key itself
        with self._cross_signing_id_gen.get_next() as stream_id:
            self.db_pool.simple_insert_txn(
                txn,
                "e2e_cross_signing_keys",
                values={
                    "user_id": user_id,
                    "keytype": key_type,
                    "keydata": json.dumps(key),
                    "stream_id": stream_id,
                },
            )

        self._invalidate_cache_and_stream(
            txn, self._get_bare_e2e_cross_signing_keys, (user_id,)
        )

    def set_e2e_cross_signing_key(self, user_id, key_type, key):
        """Set a user's cross-signing key.

        Args:
            user_id (str): the user to set the signing key for
            key_type (str): the type of cross-signing key to set
            key (dict): the key data
        """
        return self.db_pool.runInteraction(
            "add_e2e_cross_signing_key",
            self._set_e2e_cross_signing_key_txn,
            user_id,
            key_type,
            key,
        )

    def store_e2e_cross_signing_signatures(self, user_id, signatures):
        """Stores cross-signing signatures.

        Args:
            user_id (str): the user who made the signatures
            signatures (iterable[SignatureListItem]): signatures to add
        """
        return self.db_pool.simple_insert_many(
            "e2e_cross_signing_signatures",
            [
                {
                    "user_id": user_id,
                    "key_id": item.signing_key_id,
                    "target_user_id": item.target_user_id,
                    "target_device_id": item.target_device_id,
                    "signature": item.signature,
                }
                for item in signatures
            ],
            "add_e2e_signing_key",
        )