diff --git a/spec/unit/rust-crypto/RoomEncryptor.spec.ts b/spec/unit/rust-crypto/RoomEncryptor.spec.ts index de6863cbce1..9094ef38e23 100644 --- a/spec/unit/rust-crypto/RoomEncryptor.spec.ts +++ b/spec/unit/rust-crypto/RoomEncryptor.spec.ts @@ -52,6 +52,11 @@ describe("RoomEncryptor", () => { let mockOutgoingRequestManager: Mocked; let mockRoom: Mocked; + const mockRoomMember = { + userId: "@alice:example.org", + membership: "join", + } as unknown as Mocked; + function createMockEvent(text: string): Mocked { return { getTxnId: jest.fn().mockReturnValue(""), @@ -87,11 +92,6 @@ describe("RoomEncryptor", () => { doProcessOutgoingRequests: jest.fn().mockResolvedValue(undefined), } as unknown as Mocked; - const mockRoomMember = { - userId: "@alice:example.org", - membership: "join", - } as unknown as Mocked; - mockRoom = { roomId: "!foo:example.org", getJoinedMembers: jest.fn().mockReturnValue([mockRoomMember]), @@ -136,5 +136,49 @@ describe("RoomEncryptor", () => { // should have been called again expect(mockOlmMachine.shareRoomKey).toHaveBeenCalledTimes(6); }); + + // Regression test for https://github.com/element-hq/element-web/issues/26684 + it("Should maintain order of encryption requests", async () => { + const firstTargetMembers = defer(); + const secondTargetMembers = defer(); + + mockOlmMachine.shareRoomKey.mockResolvedValue(undefined); + + // Hook into this method to demonstrate the race condition + mockRoom.getEncryptionTargetMembers + .mockImplementationOnce(async () => { + await firstTargetMembers.promise; + return [mockRoomMember]; + }) + .mockImplementationOnce(async () => { + await secondTargetMembers.promise; + return [mockRoomMember]; + }); + + let firstMessageFinished: string | null = null; + + const firstRequest = roomEncryptor.encryptEvent(createMockEvent("Hello"), false); + const secondRequest = roomEncryptor.encryptEvent(createMockEvent("Edit of Hello"), false); + + firstRequest.then(() => { + if (firstMessageFinished === null) { + firstMessageFinished = "hello"; + } + }); + + secondRequest.then(() => { + if (firstMessageFinished === null) { + firstMessageFinished = "edit"; + } + }); + + // suppose the second getEncryptionTargetMembers call returns first + secondTargetMembers.resolve(); + firstTargetMembers.resolve(); + + await Promise.all([firstRequest, secondRequest]); + + expect(firstMessageFinished).toBe("hello"); + }); }); }); diff --git a/src/logger.ts b/src/logger.ts index 4ea470d61a0..7783963326f 100644 --- a/src/logger.ts +++ b/src/logger.ts @@ -29,7 +29,7 @@ export interface Logger extends BaseLogger { } /** The basic interface for a logger which doesn't support children */ -interface BaseLogger { +export interface BaseLogger { /** * Output trace message to the logger, with stack trace. * diff --git a/src/rust-crypto/RoomEncryptor.ts b/src/rust-crypto/RoomEncryptor.ts index 5de1769ba50..ef3ea777fb7 100644 --- a/src/rust-crypto/RoomEncryptor.ts +++ b/src/rust-crypto/RoomEncryptor.ts @@ -46,8 +46,12 @@ export class RoomEncryptor { /** whether the room members have been loaded and tracked for the first time */ private lazyLoadedMembersResolved = false; - /** Ensures that there is only one call to shareRoomKeys at a time */ - private currentShareRoomKeyPromise = Promise.resolve(); + /** + * Ensures that there is only one encryption operation at a time for that room. + * + * An encryption operation is either a {@link prepareForEncryption} or an {@link encryptEvent} call. + */ + private currentEncryptionPromise: Promise = Promise.resolve(); /** * @param olmMachine - The rust-sdk's OlmMachine @@ -118,11 +122,47 @@ export class RoomEncryptor { * @param globalBlacklistUnverifiedDevices - When `true`, it will not send encrypted messages to unverified devices */ public async prepareForEncryption(globalBlacklistUnverifiedDevices: boolean): Promise { - const logger = new LogSpan(this.prefixedLogger, "prepareForEncryption"); + // We consider a prepareForEncryption as an encryption promise as it will potentially share keys + // even if it doesn't send an event. + // Usually this is called when the user starts typing, so we want to make sure we have keys ready when the + // message is finally sent. + // If `encryptEvent` is invoked before `prepareForEncryption` has completed, the `encryptEvent` call will wait for + // `prepareForEncryption` to complete before executing. + // The part where `encryptEvent` shares the room key will then usually be a no-op as it was already performed by `prepareForEncryption`. + await this.encryptEvent(null, globalBlacklistUnverifiedDevices); + } - await logDuration(this.prefixedLogger, "prepareForEncryption", async () => { - await this.ensureEncryptionSession(logger, globalBlacklistUnverifiedDevices); - }); + /** + * Encrypt an event for this room, or prepare for encryption. + * + * This will ensure that we have a megolm session for this room, share it with the devices in the room, and + * then, if an event is provided, encrypt it using the session. + * + * @param event - Event to be encrypted, or null if only preparing for encryption (in which case we will pre-share the room key). + * @param globalBlacklistUnverifiedDevices - When `true`, it will not send encrypted messages to unverified devices + */ + public encryptEvent(event: MatrixEvent | null, globalBlacklistUnverifiedDevices: boolean): Promise { + const logger = new LogSpan(this.prefixedLogger, event ? event.getTxnId() ?? "" : "prepareForEncryption"); + // Ensure order of encryption to avoid message ordering issues, as the scheduler only ensures + // events order after they have been encrypted. + const prom = this.currentEncryptionPromise + .catch(() => { + // Any errors in the previous call will have been reported already, so there is nothing to do here. + // we just throw away the error and start anew. + }) + .then(async () => { + await logDuration(logger, "ensureEncryptionSession", async () => { + await this.ensureEncryptionSession(logger, globalBlacklistUnverifiedDevices); + }); + if (event) { + await logDuration(logger, "encryptEventInner", async () => { + await this.encryptEventInner(logger, event); + }); + } + }); + + this.currentEncryptionPromise = prom; + return prom; } /** @@ -151,7 +191,9 @@ export class RoomEncryptor { // This could end up being racy (if two calls to ensureEncryptionSession happen at the same time), but that's // not a particular problem, since `OlmMachine.updateTrackedUsers` just adds any users that weren't already tracked. if (!this.lazyLoadedMembersResolved) { - await this.olmMachine.updateTrackedUsers(members.map((u) => new RustSdkCryptoJs.UserId(u.userId))); + await logDuration(this.prefixedLogger, "loadMembersIfNeeded: updateTrackedUsers", async () => { + await this.olmMachine.updateTrackedUsers(members.map((u) => new RustSdkCryptoJs.UserId(u.userId))); + }); logger.debug(`Updated tracked users`); this.lazyLoadedMembersResolved = true; @@ -214,7 +256,11 @@ export class RoomEncryptor { this.room.getBlacklistUnverifiedDevices() ?? globalBlacklistUnverifiedDevices; await logDuration(this.prefixedLogger, "shareRoomKey", async () => { - const shareMessages: ToDeviceRequest[] = await this.shareRoomKey(userList, rustEncryptionSettings); + const shareMessages: ToDeviceRequest[] = await this.olmMachine.shareRoomKey( + new RoomId(this.room.roomId), + userList, + rustEncryptionSettings, + ); if (shareMessages) { for (const m of shareMessages) { await this.outgoingRequestManager.outgoingRequestProcessor.makeOutgoingRequest(m); @@ -223,30 +269,6 @@ export class RoomEncryptor { }); } - /** - * The Rust-SDK requires that we only have one shareRoomKey process in flight at once for a room. - * This method ensures that, by only having one call to shareRoomKey active at once (and making them - * queue up in order). - * - * @param userList - list of userIDs to share with - * @param rustEncryptionSettings - encryption settings to use - * - * @returns a promise which resolves to the list of ToDeviceRequests to send - */ - private async shareRoomKey( - userList: UserId[], - rustEncryptionSettings: EncryptionSettings, - ): Promise { - const prom = this.currentShareRoomKeyPromise - .catch(() => { - // any errors in the previous claim will have been reported already, so there is nothing to do here. - // we just throw away the error and start anew. - }) - .then(() => this.olmMachine.shareRoomKey(new RoomId(this.room.roomId), userList, rustEncryptionSettings)); - this.currentShareRoomKeyPromise = prom; - return prom; - } - /** * Discard any existing group session for this room */ @@ -257,19 +279,7 @@ export class RoomEncryptor { } } - /** - * Encrypt an event for this room - * - * This will ensure that we have a megolm session for this room, share it with the devices in the room, and - * then encrypt the event using the session. - * - * @param event - Event to be encrypted. - * @param globalBlacklistUnverifiedDevices - When `true`, it will not send encrypted messages to unverified devices - */ - public async encryptEvent(event: MatrixEvent, globalBlacklistUnverifiedDevices: boolean): Promise { - const logger = new LogSpan(this.prefixedLogger, event.getTxnId() ?? ""); - await this.ensureEncryptionSession(logger, globalBlacklistUnverifiedDevices); - + private async encryptEventInner(logger: LogSpan, event: MatrixEvent): Promise { logger.debug("Encrypting actual message content"); const encryptedContent = await this.olmMachine.encryptRoomEvent( new RoomId(this.room.roomId), diff --git a/src/utils.ts b/src/utils.ts index a40aac1125e..28e867eeda0 100644 --- a/src/utils.ts +++ b/src/utils.ts @@ -25,7 +25,7 @@ import { Optional } from "matrix-events-sdk"; import { IEvent, MatrixEvent } from "./models/event"; import { M_TIMESTAMP } from "./@types/location"; import { ReceiptType } from "./@types/read_receipts"; -import { Logger } from "./logger"; +import { BaseLogger } from "./logger"; const interns = new Map(); @@ -395,7 +395,7 @@ export function sleep(ms: number, value?: T): Promise { * @param name - The name of the operation. * @param block - The block to execute. */ -export async function logDuration(logger: Logger, name: string, block: () => Promise): Promise { +export async function logDuration(logger: BaseLogger, name: string, block: () => Promise): Promise { const start = Date.now(); try { return await block();