diff --git a/Sources/DXProtocol/Session/Session.swift b/Sources/DXProtocol/Session/Session.swift index 36981d5..17244d3 100644 --- a/Sources/DXProtocol/Session/Session.swift +++ b/Sources/DXProtocol/Session/Session.swift @@ -231,20 +231,24 @@ public struct Session: Codable { /// - identityStore: The identity store. /// /// - Returns: The encrypted message container. - public mutating func encrypt(data: Data, - for address: ProtocolAddress, - sessionStore: SessionStorable, - identityStore: IdentityKeyStorable) throws -> MessageContainer { + public static func encrypt(data: Data, + for address: ProtocolAddress, + sessionStore: SessionStorable, + identityStore: IdentityKeyStorable) throws -> MessageContainer { let lock = SessionLock(address: address) lock.lock() defer { lock.unlock() } - let result = try self.state.encrypt( - data: data, - sessionStore: sessionStore, - identityStore: identityStore) - try sessionStore.storeSession(self, for: address) - + guard var session = try sessionStore.loadSession(for: address) else { + throw DXError.sessionNotFound("Failed to find session while encrypting message") + } + + let result = try session.state.encrypt( + data: data, + sessionStore: sessionStore, + identityStore: identityStore) + try sessionStore.storeSession(session, for: address) + return result } @@ -367,7 +371,6 @@ extension Session { lock.lock() defer { lock.unlock() } - // This code is not covered by tests guard var session = try sessionStore.loadSession(for: address) else { throw DXError.sessionNotFound("Failed to find session while decrypting message") } diff --git a/Tests/DXProtocolTests/SessionTests/SessionTests.swift b/Tests/DXProtocolTests/SessionTests/SessionTests.swift index b627c07..738792b 100644 --- a/Tests/DXProtocolTests/SessionTests/SessionTests.swift +++ b/Tests/DXProtocolTests/SessionTests/SessionTests.swift @@ -118,6 +118,8 @@ final class SessionTests: XCTestCase { func testInitializeSessionOptionalOneTimePreKey() throws { let senderClient = try TestClient(userId: UUID()) // Alice let recipientClient = try TestClient(userId: UUID()) // Bob + let aliceAddress = senderClient.protocolAddress + let bobAddress = recipientClient.protocolAddress // Generate identity information let bobIdentityKeyPair = try recipientClient.identityKeyStore.identityKeyPair() @@ -141,24 +143,24 @@ final class SessionTests: XCTestCase { oneTimePreKey: nil) // Alice processes the bundle: - var aliceSession = try Session.processPreKeyBundle( + try Session.processPreKeyBundle( bobBundle, - for: recipientClient.protocolAddress, + for: bobAddress, sessionStore: senderClient.sessionStore, identityStore: senderClient.identityKeyStore) // Alice creates the first message (Pre Key message) let initialMessageData = try XCTUnwrap("Optional OneTime PreKey".data(using: .utf8)) - let aliceMessage = try aliceSession.encrypt( + let aliceMessage = try Session.encrypt( data: initialMessageData, - for: recipientClient.protocolAddress, + for: bobAddress, sessionStore: senderClient.sessionStore, identityStore: senderClient.identityKeyStore) // Bob decrypts the first message (Pre Key message) from Alice var result = try Session.decrypt( message: aliceMessage, - from: senderClient.protocolAddress, + from: aliceAddress, sessionStore: recipientClient.sessionStore, identityStore: recipientClient.identityKeyStore, preKeyStore: recipientClient.preKeyStore, @@ -166,19 +168,17 @@ final class SessionTests: XCTestCase { XCTAssertEqual(result, initialMessageData) // Finally, Bob sends a message back to acknowledge the pre-key. - let aliceAddress = senderClient.protocolAddress - var bobSession = try XCTUnwrap(try recipientClient.sessionStore.loadSession(for: aliceAddress)) let bobReplyData = try XCTUnwrap("Reply Optional OneTime PreKey".data(using: .utf8)) - let bobMessage = try bobSession.encrypt( + let bobMessage = try Session.encrypt( data: bobReplyData, - for: senderClient.protocolAddress, + for: aliceAddress, sessionStore: recipientClient.sessionStore, identityStore: recipientClient.identityKeyStore) // Alice decrypts first message from Bob (with acknowledge of the pre-key) result = try Session.decrypt( message: bobMessage, - from: recipientClient.protocolAddress, + from: bobAddress, sessionStore: senderClient.sessionStore, identityStore: senderClient.identityKeyStore, preKeyStore: senderClient.preKeyStore, @@ -196,11 +196,10 @@ final class SessionTests: XCTestCase { // Alice sends the second message (may be after some time so load session from storage) let bobAddress = recipientClient.protocolAddress - var aliceSession1 = try XCTUnwrap(try senderClient.sessionStore.loadSession(for: bobAddress)) let secondMessageString = "Those who stands for nothing will fall for anything" - let secondMessage = try aliceSession1.encrypt( + let secondMessage = try Session.encrypt( data: try XCTUnwrap(secondMessageString.data(using: .utf8)), - for: recipientClient.protocolAddress, + for: bobAddress, sessionStore: senderClient.sessionStore, identityStore: senderClient.identityKeyStore) @@ -227,25 +226,25 @@ final class SessionTests: XCTestCase { // Alice sends the two messages in a row let bobAddress = recipientClient.protocolAddress - var aliceSession1 = try XCTUnwrap(try senderClient.sessionStore.loadSession(for: bobAddress)) let plaintext1 = "Those who stands for nothing will fall for anything" - let cipherMessage1 = try aliceSession1.encrypt( + let cipherMessage1 = try Session.encrypt( data: try XCTUnwrap(plaintext1.data(using: .utf8)), - for: recipientClient.protocolAddress, + for: bobAddress, sessionStore: senderClient.sessionStore, identityStore: senderClient.identityKeyStore) let plaintext2 = "Do not despair when your enemy attacks you." - let cipherMessage2 = try aliceSession1.encrypt( + let cipherMessage2 = try Session.encrypt( data: try XCTUnwrap(plaintext2.data(using: .utf8)), - for: recipientClient.protocolAddress, + for: bobAddress, sessionStore: senderClient.sessionStore, identityStore: senderClient.identityKeyStore) // Bob decrypts these messages from Alice + let aliceAddress = senderClient.protocolAddress let decrypted1 = try Session.decrypt( message: cipherMessage1, - from: senderClient.protocolAddress, + from: aliceAddress, sessionStore: recipientClient.sessionStore, identityStore: recipientClient.identityKeyStore, preKeyStore: recipientClient.preKeyStore, @@ -253,7 +252,7 @@ final class SessionTests: XCTestCase { let decrypted2 = try Session.decrypt( message: cipherMessage2, - from: senderClient.protocolAddress, + from: aliceAddress, sessionStore: recipientClient.sessionStore, identityStore: recipientClient.identityKeyStore, preKeyStore: recipientClient.preKeyStore, @@ -265,27 +264,25 @@ final class SessionTests: XCTestCase { let decryptedText2 = try XCTUnwrap(String(data: decrypted2, encoding: .utf8)) XCTAssertEqual(plaintext2, decryptedText2) } - + func testThreadSafeSimultaneousDecrypt() async throws { let senderClient = try TestClient(userId: UUID()) // Alice let recipientClient = try TestClient(userId: UUID()) // Bob - let recipientAddress = recipientClient.protocolAddress - + let bobAddress = recipientClient.protocolAddress + let aliceAddress = senderClient.protocolAddress + try initializeSession(senderClient: senderClient, recipientClient: recipientClient) // Alice creates a set of messages let aliceMessageCount = 50 var aliceMessages: [(Data, MessageContainer)] = [] for index in 0..]() + for index in 0..<5 { + let task = Task { + let plaintext = "Message \(index)" + let data = try XCTUnwrap(plaintext.data(using: .utf8)) + let encrypted = try Session.encrypt( + data: data, + for: bobAddress, + sessionStore: senderClient.sessionStore, + identityStore: senderClient.identityKeyStore) + return (encrypted, data) + } + tasks.append(task) } - let task4 = Task { - var session = try XCTUnwrap(try senderClient.sessionStore.loadSession(for: recipientAddress)) - return try session.encrypt( - data: data, - for: recipientClient.protocolAddress, - sessionStore: senderClient.sessionStore, - identityStore: senderClient.identityKeyStore) + var encryptResults = [(encrypted: MessageContainer, expected: Data)]() + for task in tasks { + let result = try await task.value + encryptResults.append(result) } - let task5 = Task { - var session = try XCTUnwrap(try senderClient.sessionStore.loadSession(for: recipientAddress)) - return try session.encrypt( - data: data, - for: recipientClient.protocolAddress, - sessionStore: senderClient.sessionStore, - identityStore: senderClient.identityKeyStore) + for result in encryptResults { + let decrypted = try Session.decrypt( + message: result.encrypted, + from: senderClient.protocolAddress, + sessionStore: recipientClient.sessionStore, + identityStore: recipientClient.identityKeyStore, + preKeyStore: recipientClient.preKeyStore, + signedPreKeyStore: recipientClient.signedPreKeyStore) + + XCTAssertEqual(decrypted, result.expected) } - - let message1 = try await task1.value - let message2 = try await task2.value - let message3 = try await task3.value - let message4 = try await task4.value - let message5 = try await task5.value - - let decryptedMessage1 = try Session.decrypt( - message: message1, - from: senderClient.protocolAddress, - sessionStore: recipientClient.sessionStore, - identityStore: recipientClient.identityKeyStore, - preKeyStore: recipientClient.preKeyStore, - signedPreKeyStore: recipientClient.signedPreKeyStore) - - let decryptedMessage2 = try Session.decrypt( - message: message2, - from: senderClient.protocolAddress, - sessionStore: recipientClient.sessionStore, - identityStore: recipientClient.identityKeyStore, - preKeyStore: recipientClient.preKeyStore, - signedPreKeyStore: recipientClient.signedPreKeyStore) - - let decryptedMessage3 = try Session.decrypt( - message: message3, - from: senderClient.protocolAddress, - sessionStore: recipientClient.sessionStore, - identityStore: recipientClient.identityKeyStore, - preKeyStore: recipientClient.preKeyStore, - signedPreKeyStore: recipientClient.signedPreKeyStore) - - let decryptedMessage4 = try Session.decrypt( - message: message4, - from: senderClient.protocolAddress, - sessionStore: recipientClient.sessionStore, - identityStore: recipientClient.identityKeyStore, - preKeyStore: recipientClient.preKeyStore, - signedPreKeyStore: recipientClient.signedPreKeyStore) - - let decryptedMessage5 = try Session.decrypt( - message: message5, - from: senderClient.protocolAddress, - sessionStore: recipientClient.sessionStore, - identityStore: recipientClient.identityKeyStore, - preKeyStore: recipientClient.preKeyStore, - signedPreKeyStore: recipientClient.signedPreKeyStore) - - XCTAssertEqual(decryptedMessage1, data) - XCTAssertEqual(decryptedMessage2, data) - XCTAssertEqual(decryptedMessage3, data) - XCTAssertEqual(decryptedMessage4, data) - XCTAssertEqual(decryptedMessage5, data) } // FIXME: - Need fix @@ -438,25 +370,25 @@ final class SessionTests: XCTestCase { // Alice sends the messages let bobAddress = recipientClient.protocolAddress - var aliceSession1 = try XCTUnwrap(try senderClient.sessionStore.loadSession(for: bobAddress)) let plaintext1 = "Those who stands for nothing will fall for anything" - let cipherMessage1 = try aliceSession1.encrypt( + let cipherMessage1 = try Session.encrypt( data: try XCTUnwrap(plaintext1.data(using: .utf8)), - for: recipientClient.protocolAddress, + for: bobAddress, sessionStore: senderClient.sessionStore, identityStore: senderClient.identityKeyStore) let plaintext2 = "Do not despair when your enemy attacks you." - let cipherMessage2 = try aliceSession1.encrypt( + let cipherMessage2 = try Session.encrypt( data: try XCTUnwrap(plaintext2.data(using: .utf8)), - for: recipientClient.protocolAddress, + for: bobAddress, sessionStore: senderClient.sessionStore, identityStore: senderClient.identityKeyStore) // Bob decrypts second message first + let aliceAddress = senderClient.protocolAddress let decrypted2 = try Session.decrypt( message: cipherMessage2, - from: senderClient.protocolAddress, + from: aliceAddress, sessionStore: recipientClient.sessionStore, identityStore: recipientClient.identityKeyStore, preKeyStore: recipientClient.preKeyStore, @@ -464,7 +396,7 @@ final class SessionTests: XCTestCase { let decrypted1 = try Session.decrypt( message: cipherMessage1, - from: senderClient.protocolAddress, + from: aliceAddress, sessionStore: recipientClient.sessionStore, identityStore: recipientClient.identityKeyStore, preKeyStore: recipientClient.preKeyStore, @@ -480,18 +412,17 @@ final class SessionTests: XCTestCase { func testMessageKeyLimit() throws { let senderClient = try TestClient(userId: UUID()) // Alice let recipientClient = try TestClient(userId: UUID()) // Bob + let bobAddress = recipientClient.protocolAddress + let aliceAddress = senderClient.protocolAddress try initializeSession(senderClient: senderClient, recipientClient: recipientClient) - let bobAddress = recipientClient.protocolAddress - var aliceSession = try XCTUnwrap(try senderClient.sessionStore.loadSession(for: bobAddress)) - // Alice encrypts enough messages to hit messages keys maximum var messages = [MessageContainer]() let count = DXProtocolConstants.messageKeyMaximum + 10 for index in 0..