Support staged rollout of migration to Rust Crypto (#12184)

* Rust migration staged rollout

* Phased rollout unit tests
pull/28217/head
Valere 2024-01-31 16:52:23 +01:00 committed by GitHub
parent 73b16239a5
commit a5f9df5855
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 369 additions and 6 deletions

View File

@ -95,6 +95,7 @@
"highlight.js": "^11.3.1",
"html-entities": "^2.0.0",
"is-ip": "^3.1.0",
"js-xxhash": "^3.0.1",
"jszip": "^3.7.0",
"katex": "^0.16.0",
"linkify-element": "4.1.3",

View File

@ -0,0 +1,64 @@
/*
Copyright 2024 The Matrix.org Foundation C.I.C.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
import { test, expect } from "../../element-web-test";
import { logIntoElement } from "./utils";
test.describe("Migration of existing logins", () => {
test("Test migration of existing logins when rollout is 100%", async ({
page,
context,
app,
credentials,
homeserver,
}, workerInfo) => {
test.skip(workerInfo.project.name === "Rust Crypto", "This test only works with Rust crypto.");
await page.goto("/#/login");
let featureRustCrypto = false;
let stagedRolloutPercent = 0;
await context.route(`http://localhost:8080/config.json*`, async (route) => {
const json = {};
json["features"] = {
feature_rust_crypto: featureRustCrypto,
};
json["setting_defaults"] = {
"RustCrypto.staged_rollout_percent": stagedRolloutPercent,
};
await route.fulfill({ json });
});
await logIntoElement(page, homeserver, credentials);
await app.settings.openUserSettings("Help & About");
await expect(page.getByText("Crypto version: Olm")).toBeVisible();
featureRustCrypto = true;
await page.reload();
await app.settings.openUserSettings("Help & About");
await expect(page.getByText("Crypto version: Olm")).toBeVisible();
stagedRolloutPercent = 100;
await page.reload();
await app.settings.openUserSettings("Help & About");
await expect(page.getByText("Crypto version: Rust SDK")).toBeVisible();
});
});

View File

@ -18,15 +18,15 @@ limitations under the License.
*/
import {
ICreateClientOpts,
PendingEventOrdering,
RoomNameState,
RoomNameType,
EventTimeline,
EventTimelineSet,
ICreateClientOpts,
IStartClientOpts,
MatrixClient,
MemoryStore,
PendingEventOrdering,
RoomNameState,
RoomNameType,
TokenRefreshFunction,
} from "matrix-js-sdk/src/matrix";
import * as utils from "matrix-js-sdk/src/utils";
@ -53,6 +53,7 @@ import PlatformPeg from "./PlatformPeg";
import { formatList } from "./utils/FormattingUtils";
import SdkConfig from "./SdkConfig";
import { Features } from "./settings/Settings";
import { PhasedRolloutFeature } from "./utils/PhasedRolloutFeature";
export interface IMatrixClientCreds {
homeserverUrl: string;
@ -302,13 +303,34 @@ class MatrixClientPegClass implements IMatrixClientPeg {
throw new Error("createClient must be called first");
}
const useRustCrypto = SettingsStore.getValue(Features.RustCrypto);
let useRustCrypto = SettingsStore.getValue(Features.RustCrypto);
// We want the value that is set in the config.json for that web instance
const defaultUseRustCrypto = SettingsStore.getValueAt(SettingLevel.CONFIG, Features.RustCrypto);
const migrationPercent = SettingsStore.getValueAt(SettingLevel.CONFIG, "RustCrypto.staged_rollout_percent");
// If the default config is to use rust crypto, and the user is on legacy crypto,
// we want to check if we should migrate the current user.
if (!useRustCrypto && defaultUseRustCrypto && Number.isInteger(migrationPercent)) {
// The user is not on rust crypto, but the default stack is now rust; Let's check if we should migrate
// the current user to rust crypto.
try {
const stagedRollout = new PhasedRolloutFeature("RustCrypto.staged_rollout_percent", migrationPercent);
// Device id should not be null at that point, or init crypto will fail anyhow
const deviceId = this.matrixClient.getDeviceId()!;
// we use deviceId rather than userId because we don't particularly want all devices
// of a user to be migrated at the same time.
useRustCrypto = stagedRollout.isFeatureEnabled(deviceId);
} catch (e) {
logger.warn("Failed to create staged rollout feature for rust crypto migration", e);
}
}
// we want to make sure that the same crypto implementation is used throughout the lifetime of a device,
// so persist the setting at the device layer
// (At some point, we'll allow the user to *enable* the setting via labs, which will migrate their existing
// device to the rust-sdk implementation, but that won't change anything here).
await SettingsStore.setValue("feature_rust_crypto", null, SettingLevel.DEVICE, useRustCrypto);
await SettingsStore.setValue(Features.RustCrypto, null, SettingLevel.DEVICE, useRustCrypto);
// Now we can initialise the right crypto impl.
if (useRustCrypto) {

View File

@ -96,6 +96,7 @@ export enum Features {
VoiceBroadcastForceSmallChunks = "feature_voice_broadcast_force_small_chunks",
NotificationSettings2 = "feature_notification_settings2",
OidcNativeFlow = "feature_oidc_native_flow",
// If true, every new login will use the new rust crypto implementation
RustCrypto = "feature_rust_crypto",
}
@ -503,6 +504,13 @@ export const SETTINGS: { [setting: string]: ISetting } = {
default: false,
controller: new RustCryptoSdkController(),
},
// Must be set under `setting_defaults` in config.json.
// If set to 100 in conjunction with `feature_rust_crypto`, all existing users will migrate to the new crypto.
// Default is 0, meaning no existing users on legacy crypto will migrate.
"RustCrypto.staged_rollout_percent": {
supportedLevels: [SettingLevel.CONFIG],
default: 0,
},
"baseFontSize": {
displayName: _td("settings|appearance|font_size"),
supportedLevels: LEVELS_ACCOUNT_SETTINGS,

View File

@ -0,0 +1,63 @@
/*
Copyright 2024 New Vector Ltd
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
import { xxHash32 } from "js-xxhash";
/**
* The PhasedRolloutFeature class is used to manage the phased rollout of a new feature.
*
* It uses a hash of the user's identifier and the feature name to determine if a feature is enabled for a specific user.
* The rollout percentage determines the probability that a user will be enabled for the feature.
* The feature will be enabled for all users if the rollout percentage is 100, and for no users if the percentage is 0.
* If a user is enabled for a feature at x% rollout, it will also be for any greater than x percent.
*
* The process ensures a uniform distribution of enabled features across users.
*
* @property featureName - The name of the feature to be rolled out.
* @property rolloutPercentage - The int percentage (0..100) of users for whom the feature should be enabled.
*/
export class PhasedRolloutFeature {
public readonly featureName: string;
private readonly rolloutPercentage: number;
private readonly seed: number;
public constructor(featureName: string, rolloutPercentage: number) {
this.featureName = featureName;
if (!Number.isInteger(rolloutPercentage) || rolloutPercentage < 0 || rolloutPercentage > 100) {
throw new Error("Rollout percentage must be an integer between 0 and 100");
}
this.rolloutPercentage = rolloutPercentage;
// We add the feature name for the seed to ensure that the hash is different for each feature
this.seed = Array.from(featureName).reduce((sum, char) => sum + char.charCodeAt(0), 0);
}
/**
* Returns true if the feature should be enabled for the given user.
* @param userIdentifier - Some unique identifier for the user, e.g. their user ID or device ID.
*/
public isFeatureEnabled(userIdentifier: string): boolean {
/*
* We use a hash function to convert the unique user ID string into an integer.
* This integer can then be used as a basis for deciding whether the user should have access to the new feature.
* We need some hash with good uniform distribution properties, security is not a concern here.
* We use xxHash32, which is fast and has good distribution properties.
*/
const hash = xxHash32(userIdentifier, this.seed);
// We use the hash modulo 100 to get a number between 0 and 99.
// Modulo is simple and effective and the distribution should be uniform enough for our purposes.
return hash % 100 < this.rolloutPercentage;
}
}

View File

@ -144,6 +144,117 @@ describe("MatrixClientPeg", () => {
expect(mockSetValue).toHaveBeenCalledWith("feature_rust_crypto", null, SettingLevel.DEVICE, true);
});
describe("Rust staged rollout", () => {
function mockSettingStore(
userIsUsingRust: boolean,
newLoginShouldUseRust: boolean,
rolloutPercent: number | null,
) {
const originalGetValue = SettingsStore.getValue;
jest.spyOn(SettingsStore, "getValue").mockImplementation(
(settingName: string, roomId: string | null = null, excludeDefault = false) => {
if (settingName === "feature_rust_crypto") {
return userIsUsingRust;
}
return originalGetValue(settingName, roomId, excludeDefault);
},
);
const originalGetValueAt = SettingsStore.getValueAt;
jest.spyOn(SettingsStore, "getValueAt").mockImplementation(
(level: SettingLevel, settingName: string) => {
if (settingName === "feature_rust_crypto") {
return newLoginShouldUseRust;
}
// if null we let the original implementation handle it to get the default
if (settingName === "RustCrypto.staged_rollout_percent" && rolloutPercent !== null) {
return rolloutPercent;
}
return originalGetValueAt(level, settingName);
},
);
}
let mockSetValue: jest.SpyInstance;
let mockInitCrypto: jest.SpyInstance;
let mockInitRustCrypto: jest.SpyInstance;
beforeEach(() => {
mockSetValue = jest.spyOn(SettingsStore, "setValue").mockResolvedValue(undefined);
mockInitCrypto = jest.spyOn(testPeg.safeGet(), "initCrypto").mockResolvedValue(undefined);
mockInitRustCrypto = jest.spyOn(testPeg.safeGet(), "initRustCrypto").mockResolvedValue(undefined);
});
it("Should not migrate existing login if rollout is 0", async () => {
mockSettingStore(false, true, 0);
await testPeg.start();
expect(mockInitCrypto).toHaveBeenCalled();
expect(mockInitRustCrypto).not.toHaveBeenCalledTimes(1);
// we should have stashed the setting in the settings store
expect(mockSetValue).toHaveBeenCalledWith("feature_rust_crypto", null, SettingLevel.DEVICE, false);
});
it("Should migrate existing login if rollout is 100", async () => {
mockSettingStore(false, true, 100);
await testPeg.start();
expect(mockInitCrypto).not.toHaveBeenCalled();
expect(mockInitRustCrypto).toHaveBeenCalledTimes(1);
// we should have stashed the setting in the settings store
expect(mockSetValue).toHaveBeenCalledWith("feature_rust_crypto", null, SettingLevel.DEVICE, true);
});
it("Should migrate existing login if user is in rollout bucket", async () => {
mockSettingStore(false, true, 30);
// Use a device id that is known to be in the 30% bucket (hash modulo 100 < 30)
const spy = jest.spyOn(testPeg.get()!, "getDeviceId").mockReturnValue("AAA");
await testPeg.start();
expect(mockInitCrypto).not.toHaveBeenCalled();
expect(mockInitRustCrypto).toHaveBeenCalledTimes(1);
// we should have stashed the setting in the settings store
expect(mockSetValue).toHaveBeenCalledWith("feature_rust_crypto", null, SettingLevel.DEVICE, true);
spy.mockReset();
});
it("Should not migrate existing login if rollout is malformed", async () => {
mockSettingStore(false, true, 100.1);
await testPeg.start();
expect(mockInitCrypto).toHaveBeenCalled();
expect(mockInitRustCrypto).not.toHaveBeenCalledTimes(1);
// we should have stashed the setting in the settings store
expect(mockSetValue).toHaveBeenCalledWith("feature_rust_crypto", null, SettingLevel.DEVICE, false);
});
it("Default is to not migrate", async () => {
mockSettingStore(false, true, null);
await testPeg.start();
expect(mockInitCrypto).toHaveBeenCalled();
expect(mockInitRustCrypto).not.toHaveBeenCalledTimes(1);
// we should have stashed the setting in the settings store
expect(mockSetValue).toHaveBeenCalledWith("feature_rust_crypto", null, SettingLevel.DEVICE, false);
});
it("Should not migrate if feature_rust_crypto is false", async () => {
mockSettingStore(false, false, 100);
await testPeg.start();
expect(mockInitCrypto).toHaveBeenCalled();
expect(mockInitRustCrypto).not.toHaveBeenCalledTimes(1);
// we should have stashed the setting in the settings store
expect(mockSetValue).toHaveBeenCalledWith("feature_rust_crypto", null, SettingLevel.DEVICE, false);
});
});
it("should reload when store database closes for a guest user", async () => {
testPeg.safeGet().isGuest = () => true;
const emitter = new EventEmitter();

View File

@ -0,0 +1,89 @@
/*
Copyright 2024 The Matrix.org Foundation C.I.C.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
import { PhasedRolloutFeature } from "../../src/utils/PhasedRolloutFeature";
describe("Test PhasedRolloutFeature", () => {
function randomUserId() {
const characters = "abcdefghijklmnopqrstuvwxyz0123456789.=_-/+";
let result = "";
const charactersLength = characters.length;
const idLength = Math.floor(Math.random() * 15) + 6; // Random number between 6 and 20
for (let i = 0; i < idLength; i++) {
result += characters.charAt(Math.floor(Math.random() * charactersLength));
}
return "@" + result + ":matrix.org";
}
function randomDeviceId() {
const characters = "ABCDEFGHIJKLMNOPQRSTUVWXYZ";
let result = "";
const charactersLength = characters.length;
for (let i = 0; i < 10; i++) {
result += characters.charAt(Math.floor(Math.random() * charactersLength));
}
return result;
}
it("should only accept valid percentage", () => {
expect(() => new PhasedRolloutFeature("test", 0.8)).toThrow();
expect(() => new PhasedRolloutFeature("test", -1)).toThrow();
expect(() => new PhasedRolloutFeature("test", 123)).toThrow();
});
it("should enable for all if percentage is 100", () => {
const phasedRolloutFeature = new PhasedRolloutFeature("test", 100);
for (let i = 0; i < 1000; i++) {
expect(phasedRolloutFeature.isFeatureEnabled(randomUserId())).toBeTruthy();
}
});
it("should not enable for anyone if percentage is 0", () => {
const phasedRolloutFeature = new PhasedRolloutFeature("test", 0);
for (let i = 0; i < 1000; i++) {
expect(phasedRolloutFeature.isFeatureEnabled(randomUserId())).toBeFalsy();
}
});
it("should enable for more users if percentage grows", () => {
let rolloutPercentage = 0;
let previousBatch: string[] = [];
const allUsers = new Array(1000).fill(0).map(() => randomDeviceId());
while (rolloutPercentage <= 90) {
rolloutPercentage += 10;
const nextRollout = new PhasedRolloutFeature("test", rolloutPercentage);
const nextBatch = allUsers.filter((userId) => nextRollout.isFeatureEnabled(userId));
expect(previousBatch.length).toBeLessThan(nextBatch.length);
expect(previousBatch.every((user) => nextBatch.includes(user))).toBeTruthy();
previousBatch = nextBatch;
}
});
it("should distribute differently depending on the feature name", () => {
const allUsers = new Array(1000).fill(0).map(() => randomUserId());
const featureARollout = new PhasedRolloutFeature("FeatureA", 50);
const featureBRollout = new PhasedRolloutFeature("FeatureB", 50);
const featureAUsers = allUsers.filter((userId) => featureARollout.isFeatureEnabled(userId));
const featureBUsers = allUsers.filter((userId) => featureBRollout.isFeatureEnabled(userId));
expect(featureAUsers).not.toEqual(featureBUsers);
});
});

View File

@ -6527,6 +6527,11 @@ jest@^29.6.2:
resolved "https://registry.yarnpkg.com/js-tokens/-/js-tokens-4.0.0.tgz#19203fb59991df98e3a287050d4647cdeaf32499"
integrity sha512-RdJUflcE3cUzKiMqQgsCu06FPu9UdIJO0beYbPhHN4k6apgJtifcoCtT9bcxOpYBtpD2kCM6Sbzg4CausW/PKQ==
js-xxhash@^3.0.1:
version "3.0.1"
resolved "https://registry.yarnpkg.com/js-xxhash/-/js-xxhash-3.0.1.tgz#e093b53d02cd80a830d61f58290c206aaa877b24"
integrity sha512-Y2NSC77RIxJrvi2NoXjMi2LYsVDTlVqBoQRi8PXQg4PtP29wdtIOhsp8Ujw4EjEkBFheCPx8bMOmI9zoxx/3jQ==
js-yaml@^3.13.1:
version "3.14.1"
resolved "https://registry.yarnpkg.com/js-yaml/-/js-yaml-3.14.1.tgz#dae812fdb3825fa306609a8717383c50c36a0537"