206 lines
		
	
	
		
			7.4 KiB
		
	
	
	
		
			TypeScript
		
	
	
			
		
		
	
	
			206 lines
		
	
	
		
			7.4 KiB
		
	
	
	
		
			TypeScript
		
	
	
| /*
 | |
| Copyright 2018 - 2021 The Matrix.org Foundation C.I.C.
 | |
| 
 | |
| Licensed under the Apache License, Version 2.0 (the "License");
 | |
| you may not use this file except in compliance with the License.
 | |
| You may obtain a copy of the License at
 | |
| 
 | |
|     http://www.apache.org/licenses/LICENSE-2.0
 | |
| 
 | |
| Unless required by applicable law or agreed to in writing, software
 | |
| distributed under the License is distributed on an "AS IS" BASIS,
 | |
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | |
| See the License for the specific language governing permissions and
 | |
| limitations under the License.
 | |
| */
 | |
| 
 | |
| import { MatrixError } from "matrix-js-sdk/src/http-api";
 | |
| import { MatrixEvent } from "matrix-js-sdk/src/models/event";
 | |
| 
 | |
/**
 * A single failed attempt to decrypt an event.
 *
 * The record is stamped with the wall-clock time at which it was created.
 */
export class DecryptionFailure {
    /** ID of the event that could not be decrypted. */
    public readonly failedEventId: string;

    /** Error code that was supplied when this failure was recorded. */
    public readonly errorCode: string;

    /** `Date.now()` at the moment this record was constructed (ms since epoch). */
    public readonly ts: number;

    constructor(failedEventId: string, errorCode: string) {
        this.failedEventId = failedEventId;
        this.errorCode = errorCode;
        this.ts = Date.now();
    }
}
 | |
| 
 | |
/**
 * Callback that reports an aggregated number of decryption failures.
 *
 * `count` is the number of failures accumulated for one error code and
 * `trackedErrCode` is the (possibly mapped) error code they are reported under.
 */
type TrackingFn = (count: number, trackedErrCode: string) => void;

/**
 * Translates a raw error code into the code that is passed to the `TrackingFn`.
 */
type ErrCodeMapFn = (errcode: string) => string;
 | |
| 
 | |
/**
 * Collects decryption failures, gives each one a grace period in which the event
 * may still be decrypted, and then reports the surviving failures in aggregate
 * (one count per error code) through a caller-supplied tracking function.
 */
export class DecryptionFailureTracker {
    // Failures that have been reported but not yet processed. Every
    // `CHECK_INTERVAL_MS`, this list is checked for failures that happened more than
    // `GRACE_PERIOD_MS` ago. Those that did are accumulated in `failureCounts`.
    public failures: DecryptionFailure[] = [];

    // A histogram of the number of failures that will be tracked at the next tracking
    // interval, split by failure error code, e.g. `{ [errorCode]: 42 }`.
    public failureCounts: Record<string, number> = {};

    // Event IDs of failures that were tracked previously, e.g. `{ [eventId]: true }`.
    // This only ever grows; it is deliberately not persisted (see `checkFailures`).
    public trackedEventHashMap: Record<string, boolean> = {};

    // Set to an interval ID while the tracker is started, `null` otherwise.
    public checkInterval: NodeJS.Timeout | null = null;
    public trackInterval: NodeJS.Timeout | null = null;

    // Spread the load on `Analytics` by tracking at a low frequency, `TRACK_INTERVAL_MS`.
    static TRACK_INTERVAL_MS = 60000;

    // Call `checkFailures` every `CHECK_INTERVAL_MS`.
    static CHECK_INTERVAL_MS = 5000;

    // Give events a chance to be decrypted by waiting `GRACE_PERIOD_MS` before counting
    // the failure in `failureCounts`.
    static GRACE_PERIOD_MS = 60000;

    /**
     * Create a new DecryptionFailureTracker.
     *
     * Call `eventDecrypted(event, err)` on this instance when an event is decrypted.
     *
     * Call `start()` to start the tracker, and `stop()` to stop tracking.
     *
     * @param fn The tracking function, which will be called when failures
     * are tracked. The function should have a signature `(count, trackedErrorCode) => {...}`,
     * where `count` is the number of failures and `trackedErrorCode` matches the `.code` of
     * the errors passed to `eventDecrypted` (unless `errorCodeMapFn` is specified).
     * @param errorCodeMapFn The function used to map error codes to the
     * trackedErrorCode. If not provided, the `.code` of errors will be used.
     * @throws Error if `fn` is not a function, or if `errorCodeMapFn` is provided
     * but is not a function.
     */
    constructor(private readonly fn: TrackingFn, private readonly errorCodeMapFn?: ErrCodeMapFn) {
        if (!fn || typeof fn !== 'function') {
            throw new Error('DecryptionFailureTracker requires tracking function');
        }

        if (errorCodeMapFn && typeof errorCodeMapFn !== 'function') {
            throw new Error('DecryptionFailureTracker second constructor argument should be a function');
        }
    }

    /**
     * Record the outcome of a decryption attempt.
     *
     * @param e The event whose decryption was attempted.
     * @param err The decryption error, or a falsy value if decryption succeeded.
     * When truthy, a failure is queued under `err.code`; otherwise any queued
     * failures for the event are dropped.
     */
    public eventDecrypted(e: MatrixEvent, err: MatrixError | Error): void {
        if (err) {
            this.addDecryptionFailure(new DecryptionFailure(e.getId(), err.code));
        } else {
            // Could be an event in the failures, remove it
            this.removeDecryptionFailuresForEvent(e);
        }
    }

    /**
     * Queue a failure so that it is considered by the next `checkFailures` call.
     *
     * @param failure The failure to queue.
     */
    public addDecryptionFailure(failure: DecryptionFailure): void {
        this.failures.push(failure);
    }

    /**
     * Drop every queued failure that belongs to the given event.
     *
     * @param e The event whose queued failures should be discarded.
     */
    public removeDecryptionFailuresForEvent(e: MatrixEvent): void {
        // Look the ID up once rather than once per queued failure.
        const eventId = e.getId();
        this.failures = this.failures.filter((f) => f.failedEventId !== eventId);
    }

    /**
     * Start checking for and tracking failures.
     *
     * Safe to call more than once: timers from a previous `start()` are cleared
     * first so that they are never orphaned.
     */
    public start(): void {
        this.clearIntervals();

        this.checkInterval = setInterval(
            () => this.checkFailures(Date.now()),
            DecryptionFailureTracker.CHECK_INTERVAL_MS,
        );

        this.trackInterval = setInterval(
            () => this.trackFailures(),
            DecryptionFailureTracker.TRACK_INTERVAL_MS,
        );
    }

    /**
     * Clear state and stop checking for and tracking failures.
     *
     * Queued failures and pending counts are discarded; the record of previously
     * tracked event IDs is kept.
     */
    public stop(): void {
        this.clearIntervals();

        this.failures = [];
        this.failureCounts = {};
    }

    /**
     * Mark failures that occurred before nowTs - GRACE_PERIOD_MS as failures that should be
     * tracked. Only mark one failure per event ID.
     * @param nowTs the timestamp that represents the time now.
     */
    public checkFailures(nowTs: number): void {
        // Split the queue in a single pass: failures whose grace period has elapsed
        // are processed now, the rest stay queued for a later check.
        const failuresGivenGrace: DecryptionFailure[] = [];
        const failuresNotReady: DecryptionFailure[] = [];
        for (const failure of this.failures) {
            if (nowTs > failure.ts + DecryptionFailureTracker.GRACE_PERIOD_MS) {
                failuresGivenGrace.push(failure);
            } else {
                failuresNotReady.push(failure);
            }
        }
        this.failures = failuresNotReady;

        // Only track one failure per event, and never an event that was tracked before.
        // Use a map to preserve key ordering.
        const dedupedFailuresMap = new Map<string, DecryptionFailure>();
        for (const failure of failuresGivenGrace) {
            if (!this.trackedEventHashMap[failure.failedEventId]) {
                dedupedFailuresMap.set(failure.failedEventId, failure);
            }
        }

        // Record the newly tracked IDs in place; copying the whole (ever-growing)
        // map for every ID would be quadratic.
        for (const eventId of dedupedFailuresMap.keys()) {
            this.trackedEventHashMap[eventId] = true;
        }

        // The tracked-event map is intentionally not persisted to localStorage for
        // now: we need to consider the unbounded nature of storing it there.

        this.aggregateFailures(dedupedFailuresMap.values());
    }

    /**
     * Add one to the pending count of each given failure's error code.
     *
     * @param failures The failures to fold into `failureCounts`.
     */
    private aggregateFailures(failures: Iterable<DecryptionFailure>): void {
        for (const failure of failures) {
            const errorCode = failure.errorCode;
            this.failureCounts[errorCode] = (this.failureCounts[errorCode] || 0) + 1;
        }
    }

    /**
     * If there are failures that should be tracked, call the tracking function once
     * per error code with the number of failures accumulated for that code, then
     * reset that code's count to zero.
     */
    public trackFailures(): void {
        for (const errorCode of Object.keys(this.failureCounts)) {
            if (this.failureCounts[errorCode] > 0) {
                const trackedErrorCode = this.errorCodeMapFn ? this.errorCodeMapFn(errorCode) : errorCode;

                this.fn(this.failureCounts[errorCode], trackedErrorCode);
                this.failureCounts[errorCode] = 0;
            }
        }
    }

    /**
     * Cancel both timers, if running, and forget their handles.
     */
    private clearIntervals(): void {
        if (this.checkInterval !== null) {
            clearInterval(this.checkInterval);
            this.checkInterval = null;
        }
        if (this.trackInterval !== null) {
            clearInterval(this.trackInterval);
            this.trackInterval = null;
        }
    }
}
 |