/*
Copyright 2018 - 2021 The Matrix.org Foundation C.I.C.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

import { DecryptionError } from "matrix-js-sdk/src/crypto/algorithms";
import { MatrixEvent } from "matrix-js-sdk/src/models/event";
import { Error as ErrorEvent } from "@matrix-org/analytics-events/types/typescript/Error";

import { PosthogAnalytics } from "./PosthogAnalytics";

/**
 * Record of a single failed decryption attempt.
 *
 * The creation time is captured in `ts` so that callers can later decide how
 * long ago the failure happened.
 */
export class DecryptionFailure {
    /** Millisecond timestamp (`Date.now()`) taken when this record was constructed. */
    public readonly ts: number;

    /**
     * @param failedEventId ID of the event that could not be decrypted.
     * @param errorCode Error code describing why decryption failed.
     */
    public constructor(public readonly failedEventId: string, public readonly errorCode: string) {
        this.ts = Date.now();
    }
}

/** Aggregated error names under which decryption failures are reported. */
type ErrorCode = "OlmKeysNotSentError" | "OlmIndexError" | "UnknownError" | "OlmUnspecifiedError";

/**
 * Callback that reports a batch of failures.
 *
 * `count` is the number of failures, `trackedErrCode` the aggregated name and
 * `rawError` the original, unmapped error code.
 */
type TrackingFn = (count: number, trackedErrCode: ErrorCode, rawError: string) => void;

/** Maps a raw error code to the aggregated `ErrorCode` it is reported under. */
export type ErrCodeMapFn = (errcode: string) => ErrorCode;

/**
 * Collects decryption failures for events the user has actually seen and
 * periodically reports them, grouped by error code, through a tracking
 * function.
 *
 * Typical usage: call `eventDecrypted(event, err)` after every decryption
 * attempt, `addVisibleEvent(event)` when an event is shown, `start()` to begin
 * the periodic checks and `stop()` to end them.
 */
export class DecryptionFailureTracker {
    private static internalInstance = new DecryptionFailureTracker(
        (total, errorCode, rawError) => {
            // One analytics event is emitted per failure.
            for (let i = 0; i < total; i++) {
                PosthogAnalytics.instance.trackEvent<ErrorEvent>({
                    eventName: "Error",
                    domain: "E2EE",
                    name: errorCode,
                    context: `mxc_crypto_error_type_${rawError}`,
                });
            }
        },
        (errorCode) => {
            // Map JS-SDK error codes to tracker codes for aggregation
            switch (errorCode) {
                case "MEGOLM_UNKNOWN_INBOUND_SESSION_ID":
                    return "OlmKeysNotSentError";
                case "OLM_UNKNOWN_MESSAGE_INDEX":
                    return "OlmIndexError";
                case undefined:
                // `failureCounts` is a string-keyed record, so a failure whose error had
                // no `.code` reaches this function as the string "undefined".
                case "undefined":
                    return "OlmUnspecifiedError";
                default:
                    return "UnknownError";
            }
        },
    );

    // Map of event IDs to DecryptionFailure items.
    public failures: Map<string, DecryptionFailure> = new Map();

    // Set of event IDs that have been visible to the user.
    public visibleEvents: Set<string> = new Set();

    // Map of visible event IDs to `DecryptionFailure`s. Every
    // `CHECK_INTERVAL_MS`, this map is checked for failures that
    // happened > `GRACE_PERIOD_MS` ago. Those that did are
    // accumulated in `failureCounts`.
    public visibleFailures: Map<string, DecryptionFailure> = new Map();

    // A histogram of the number of failures that will be tracked at the next tracking
    // interval, split by failure error code.
    public failureCounts: Record<string, number> = {
        // [errorCode]: 42
    };

    // Event IDs of failures that were tracked previously
    public trackedEvents: Set<string> = new Set();

    // Set to an interval ID when `start` is called, and back to `null` by `stop`.
    public checkInterval: number | null = null;
    public trackInterval: number | null = null;

    // Spread the load on `Analytics` by tracking at a low frequency, `TRACK_INTERVAL_MS`.
    public static TRACK_INTERVAL_MS = 60000;

    // Call `checkFailures` every `CHECK_INTERVAL_MS`.
    public static CHECK_INTERVAL_MS = 5000;

    // Give events a chance to be decrypted by waiting `GRACE_PERIOD_MS` before counting
    // the failure in `failureCounts`.
    public static GRACE_PERIOD_MS = 4000;

    /**
     * Create a new DecryptionFailureTracker.
     *
     * Call `eventDecrypted(event, err)` on this instance when an event is decrypted.
     *
     * Call `start()` to start the tracker, and `stop()` to stop tracking.
     *
     * @param {function} fn The tracking function, which will be called when failures
     * are tracked. The function should have a signature
     * `(count, trackedErrorCode, rawError) => {...}`, where `count` is the number of
     * failures, `trackedErrorCode` is the result of `errorCodeMapFn` and `rawError` is
     * the unmapped error code.
     * @param {function} errorCodeMapFn The function used to map error codes to the
     * trackedErrorCode.
     * @throws {Error} if either argument is not a function.
     */
    private constructor(private readonly fn: TrackingFn, private readonly errorCodeMapFn: ErrCodeMapFn) {
        if (typeof fn !== "function") {
            throw new Error("DecryptionFailureTracker requires tracking function");
        }

        if (typeof errorCodeMapFn !== "function") {
            throw new Error("DecryptionFailureTracker second constructor argument should be a function");
        }
    }

    /** The shared tracker instance. */
    public static get instance(): DecryptionFailureTracker {
        return DecryptionFailureTracker.internalInstance;
    }

    // loadTrackedEvents() {
    //     this.trackedEvents = new Set(JSON.parse(localStorage.getItem('mx-decryption-failure-event-ids')) || []);
    // }

    // saveTrackedEvents() {
    //     localStorage.setItem('mx-decryption-failure-event-ids', JSON.stringify([...this.trackedEvents]));
    // }

    /**
     * Record the outcome of a decryption attempt.
     *
     * @param e The event that was (or failed to be) decrypted.
     * @param err The decryption error; a falsy value means decryption succeeded,
     * in which case any failure recorded for the event is removed.
     */
    public eventDecrypted(e: MatrixEvent, err: DecryptionError): void {
        // for now we only track megolm decryption failures
        if (e.getWireContent().algorithm !== "m.megolm.v1.aes-sha2") {
            return;
        }
        if (err) {
            this.addDecryptionFailure(new DecryptionFailure(e.getId()!, err.code));
        } else {
            // Could be an event in the failures, remove it
            this.removeDecryptionFailuresForEvent(e);
        }
    }

    /**
     * Mark an event as having been visible to the user. If a failure is already
     * recorded for it, the failure becomes eligible for tracking.
     *
     * @param e The event that became visible.
     */
    public addVisibleEvent(e: MatrixEvent): void {
        const eventId = e.getId()!;

        // Already reported: nothing more to record for this event.
        if (this.trackedEvents.has(eventId)) {
            return;
        }

        this.visibleEvents.add(eventId);

        const failure = this.failures.get(eventId);
        if (failure !== undefined && !this.visibleFailures.has(eventId)) {
            this.visibleFailures.set(eventId, failure);
        }
    }

    /**
     * Record a decryption failure. If the event has already been visible, the
     * failure becomes eligible for tracking.
     *
     * @param failure The failure to record.
     */
    public addDecryptionFailure(failure: DecryptionFailure): void {
        const eventId = failure.failedEventId;

        // Already reported: nothing more to record for this event.
        if (this.trackedEvents.has(eventId)) {
            return;
        }

        this.failures.set(eventId, failure);
        if (this.visibleEvents.has(eventId) && !this.visibleFailures.has(eventId)) {
            this.visibleFailures.set(eventId, failure);
        }
    }

    /**
     * Forget any recorded failure for the given event.
     *
     * @param e The event whose failure records should be dropped.
     */
    public removeDecryptionFailuresForEvent(e: MatrixEvent): void {
        const eventId = e.getId()!;
        this.failures.delete(eventId);
        this.visibleFailures.delete(eventId);
    }

    /**
     * Start checking for and tracking failures.
     *
     * Calling this while the tracker is already running replaces the existing
     * timers instead of leaking them.
     */
    public start(): void {
        this.clearTimers();

        this.checkInterval = window.setInterval(
            () => this.checkFailures(Date.now()),
            DecryptionFailureTracker.CHECK_INTERVAL_MS,
        );

        this.trackInterval = window.setInterval(() => this.trackFailures(), DecryptionFailureTracker.TRACK_INTERVAL_MS);
    }

    /**
     * Clear state and stop checking for and tracking failures.
     *
     * `trackedEvents` is not reset by this method.
     */
    public stop(): void {
        this.clearTimers();

        this.failures = new Map();
        this.visibleEvents = new Set();
        this.visibleFailures = new Map();
        this.failureCounts = {};
    }

    /** Cancel any running interval timers and forget their IDs. */
    private clearTimers(): void {
        if (this.checkInterval !== null) clearInterval(this.checkInterval);
        if (this.trackInterval !== null) clearInterval(this.trackInterval);
        this.checkInterval = null;
        this.trackInterval = null;
    }

    /**
     * Mark failures that occurred before nowTs - GRACE_PERIOD_MS as failures that should be
     * tracked. Only mark one failure per event ID.
     * @param {number} nowTs the timestamp that represents the time now.
     */
    public checkFailures(nowTs: number): void {
        const failuresGivenGrace: Set<DecryptionFailure> = new Set();
        const failuresNotReady: Map<string, DecryptionFailure> = new Map();
        for (const [eventId, failure] of this.visibleFailures) {
            if (nowTs > failure.ts + DecryptionFailureTracker.GRACE_PERIOD_MS) {
                failuresGivenGrace.add(failure);
                this.trackedEvents.add(eventId);
            } else {
                failuresNotReady.set(eventId, failure);
            }
        }
        // Only failures still inside their grace period remain pending.
        this.visibleFailures = failuresNotReady;

        // Commented out for now for expediency, we need to consider unbound nature of storing
        // this in localStorage
        // this.saveTrackedEvents();

        this.aggregateFailures(failuresGivenGrace);
    }

    /**
     * Add the given failures to the `failureCounts` histogram, keyed by error code.
     *
     * @param failures The failures to count.
     */
    private aggregateFailures(failures: Set<DecryptionFailure>): void {
        for (const failure of failures) {
            const errorCode = failure.errorCode;
            this.failureCounts[errorCode] = (this.failureCounts[errorCode] ?? 0) + 1;
        }
    }

    /**
     * If there are failures that should be tracked, call the tracking function
     * once per error code with the number of failures, the mapped error code and
     * the raw error code, then reset that code's count to zero.
     */
    public trackFailures(): void {
        for (const [errorCode, count] of Object.entries(this.failureCounts)) {
            if (count > 0) {
                const trackedErrorCode = this.errorCodeMapFn(errorCode);

                this.fn(count, trackedErrorCode, errorCode);
                this.failureCounts[errorCode] = 0;
            }
        }
    }
}