diff --git a/changelog.d/7594.bugfix b/changelog.d/7594.bugfix new file mode 100644 index 0000000000..f0c067e184 --- /dev/null +++ b/changelog.d/7594.bugfix @@ -0,0 +1 @@ +Fix a bug causing the cross-signing keys to be ignored when resyncing a device list. diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index 2cbb695bb1..230d170258 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -15,6 +15,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging +from typing import Any, Dict, Optional from six import iteritems, itervalues @@ -30,7 +31,11 @@ from synapse.api.errors import ( ) from synapse.logging.opentracing import log_kv, set_tag, trace from synapse.metrics.background_process_metrics import run_as_background_process -from synapse.types import RoomStreamToken, get_domain_from_id +from synapse.types import ( + RoomStreamToken, + get_domain_from_id, + get_verify_key_from_cross_signing_key, +) from synapse.util import stringutils from synapse.util.async_helpers import Linearizer from synapse.util.caches.expiringcache import ExpiringCache @@ -795,6 +800,13 @@ class DeviceListUpdater(object): stream_id = result["stream_id"] devices = result["devices"] + # Get the master key and the self-signing key for this user if provided in the + # response (None if not in the response). + # The response will not contain the user signing key, as this key is only used by + # its owner, thus it doesn't make sense to send it over federation. + master_key = result.get("master_key") + self_signing_key = result.get("self_signing_key") + # If the remote server has more than ~1000 devices for this user # we assume that something is going horribly wrong (e.g. a bot # that logs in and creates a new device every time it tries to @@ -824,6 +836,13 @@ class DeviceListUpdater(object): yield self.store.update_remote_device_list_cache(user_id, devices, stream_id) device_ids = [device["device_id"] for device in devices] + + # Handle cross-signing keys. + cross_signing_device_ids = yield self.process_cross_signing_key_update( + user_id, master_key, self_signing_key, + ) + device_ids = device_ids + cross_signing_device_ids + yield self.device_handler.notify_device_update(user_id, device_ids) # We clobber the seen updates since we've re-synced from a given @@ -831,3 +850,40 @@ class DeviceListUpdater(object): self._seen_updates[user_id] = {stream_id} defer.returnValue(result) + + @defer.inlineCallbacks + def process_cross_signing_key_update( + self, + user_id: str, + master_key: Optional[Dict[str, Any]], + self_signing_key: Optional[Dict[str, Any]], + ) -> list: + """Process the given new master and self-signing key for the given remote user. + + Args: + user_id: The ID of the user these keys are for. + master_key: The dict of the cross-signing master key as returned by the + remote server. + self_signing_key: The dict of the cross-signing self-signing key as returned + by the remote server. + + Return: + The device IDs for the given keys. + """ + device_ids = [] + + if master_key: + yield self.store.set_e2e_cross_signing_key(user_id, "master", master_key) + _, verify_key = get_verify_key_from_cross_signing_key(master_key) + # verify_key is a VerifyKey from signedjson, which uses + # .version to denote the portion of the key ID after the + # algorithm and colon, which is the device ID + device_ids.append(verify_key.version) + if self_signing_key: + yield self.store.set_e2e_cross_signing_key( + user_id, "self_signing", self_signing_key + ) + _, verify_key = get_verify_key_from_cross_signing_key(self_signing_key) + device_ids.append(verify_key.version) + + return device_ids diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py index 8f1bc0323c..774a252619 100644 --- a/synapse/handlers/e2e_keys.py +++ b/synapse/handlers/e2e_keys.py @@ -1291,6 +1291,7 @@ class SigningKeyEduUpdater(object): """ device_handler = self.e2e_keys_handler.device_handler + device_list_updater = device_handler.device_list_updater with (yield self._remote_edu_linearizer.queue(user_id)): pending_updates = self._pending_updates.pop(user_id, []) @@ -1303,22 +1304,9 @@ class SigningKeyEduUpdater(object): logger.info("pending updates: %r", pending_updates) for master_key, self_signing_key in pending_updates: - if master_key: - yield self.store.set_e2e_cross_signing_key( - user_id, "master", master_key - ) - _, verify_key = get_verify_key_from_cross_signing_key(master_key) - # verify_key is a VerifyKey from signedjson, which uses - # .version to denote the portion of the key ID after the - # algorithm and colon, which is the device ID - device_ids.append(verify_key.version) - if self_signing_key: - yield self.store.set_e2e_cross_signing_key( - user_id, "self_signing", self_signing_key - ) - _, verify_key = get_verify_key_from_cross_signing_key( - self_signing_key - ) - device_ids.append(verify_key.version) + new_device_ids = yield device_list_updater.process_cross_signing_key_update( + user_id, master_key, self_signing_key, + ) + device_ids = device_ids + new_device_ids yield device_handler.notify_device_update(user_id, device_ids) diff --git a/tests/test_federation.py b/tests/test_federation.py index c5099dd039..c662195eec 100644 --- a/tests/test_federation.py +++ b/tests/test_federation.py @@ -206,3 +206,59 @@ class MessageAcceptTests(unittest.HomeserverTestCase): # list. self.reactor.advance(30) self.assertEqual(self.resync_attempts, 2) + + def test_cross_signing_keys_retry(self): + """Tests that resyncing a device list correctly processes cross-signing keys from + the remote server. + """ + remote_user_id = "@john:test_remote" + remote_master_key = "85T7JXPFBAySB/jwby4S3lBPTqY3+Zg53nYuGmu1ggY" + remote_self_signing_key = "QeIiFEjluPBtI7WQdG365QKZcFs9kqmHir6RBD0//nQ" + + # Register mock device list retrieval on the federation client. + federation_client = self.homeserver.get_federation_client() + federation_client.query_user_devices = Mock( + return_value={ + "user_id": remote_user_id, + "stream_id": 1, + "devices": [], + "master_key": { + "user_id": remote_user_id, + "usage": ["master"], + "keys": {"ed25519:" + remote_master_key: remote_master_key}, + }, + "self_signing_key": { + "user_id": remote_user_id, + "usage": ["self_signing"], + "keys": { + "ed25519:" + remote_self_signing_key: remote_self_signing_key + }, + }, + } + ) + + # Resync the device list. + device_handler = self.homeserver.get_device_handler() + self.get_success( + device_handler.device_list_updater.user_device_resync(remote_user_id), + ) + + # Retrieve the cross-signing keys for this user. + keys = self.get_success( + self.store.get_e2e_cross_signing_keys_bulk(user_ids=[remote_user_id]), + ) + self.assertTrue(remote_user_id in keys) + + # Check that the master key is the one returned by the mock. + master_key = keys[remote_user_id]["master"] + self.assertEqual(len(master_key["keys"]), 1) + self.assertTrue("ed25519:" + remote_master_key in master_key["keys"].keys()) + self.assertTrue(remote_master_key in master_key["keys"].values()) + + # Check that the self-signing key is the one returned by the mock. + self_signing_key = keys[remote_user_id]["self_signing"] + self.assertEqual(len(self_signing_key["keys"]), 1) + self.assertTrue( + "ed25519:" + remote_self_signing_key in self_signing_key["keys"].keys(), + ) + self.assertTrue(remote_self_signing_key in self_signing_key["keys"].values())