From 952d834d87da0954a0363065399189bd3af24511 Mon Sep 17 00:00:00 2001 From: "slav.babanin" Date: Thu, 5 Jun 2025 00:20:10 -0700 Subject: [PATCH 1/5] Merge changes from tls-channel to prevent accidentally calling SSLEngine.beginHandshake more than once. JAVA-5797 --- .../tlschannel/impl/TlsChannelImpl.java | 28 ++++++++++++++----- 1 file changed, 21 insertions(+), 7 deletions(-) diff --git a/driver-core/src/main/com/mongodb/internal/connection/tlschannel/impl/TlsChannelImpl.java b/driver-core/src/main/com/mongodb/internal/connection/tlschannel/impl/TlsChannelImpl.java index 3c845ce6d08..7bc10f91795 100644 --- a/driver-core/src/main/com/mongodb/internal/connection/tlschannel/impl/TlsChannelImpl.java +++ b/driver-core/src/main/com/mongodb/internal/connection/tlschannel/impl/TlsChannelImpl.java @@ -159,7 +159,9 @@ public TlsChannelImpl( private final Lock readLock = new ReentrantLock(); private final Lock writeLock = new ReentrantLock(); - private volatile boolean negotiated = false; + private boolean handshakeStarted = false; + + private volatile boolean handshakeCompleted = false; /** * Whether a IOException was received from the underlying channel or from the {@link SSLEngine}. @@ -526,14 +528,27 @@ public void handshake() throws IOException { } private void doHandshake(boolean force) throws IOException, EofException { - if (!force && negotiated) return; + if (!force && handshakeCompleted) { + return; + } initLock.lock(); try { if (invalid || shutdownSent) throw new ClosedChannelException(); - if (force || !negotiated) { - engine.beginHandshake(); - LOGGER.trace("Called engine.beginHandshake()"); - handshake(Optional.empty(), Optional.empty()); + if (force || !handshakeCompleted) { + + if (!handshakeStarted) { + engine.beginHandshake(); + LOGGER.trace("Called engine.beginHandshake()"); + handshake(Optional.empty(), Optional.empty()); + + // Some engines that do not support renegotiations may be sensitive to calling + // SSLEngine.beginHandshake() more than once. This guard prevents that. + // See: https://github.com/marianobarrios/tls-channel/issues/197 + handshakeStarted = true; + } + + handshakeCompleted = true; + // call client code try { initSessionCallback.accept(engine.getSession()); @@ -541,7 +556,6 @@ private void doHandshake(boolean force) throws IOException, EofException { LOGGER.trace("client code threw exception in session initialization callback", e); throw new TlsChannelCallbackException("session initialization callback failed", e); } - negotiated = true; } } finally { initLock.unlock(); From a0808f8de1f7e5366be403e53c4f142502741178 Mon Sep 17 00:00:00 2001 From: "slav.babanin" Date: Thu, 5 Jun 2025 22:45:29 -0700 Subject: [PATCH 2/5] Perform handshake after marking handshake started. --- .../internal/connection/tlschannel/impl/TlsChannelImpl.java | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/driver-core/src/main/com/mongodb/internal/connection/tlschannel/impl/TlsChannelImpl.java b/driver-core/src/main/com/mongodb/internal/connection/tlschannel/impl/TlsChannelImpl.java index 7bc10f91795..20bc69e81f0 100644 --- a/driver-core/src/main/com/mongodb/internal/connection/tlschannel/impl/TlsChannelImpl.java +++ b/driver-core/src/main/com/mongodb/internal/connection/tlschannel/impl/TlsChannelImpl.java @@ -539,7 +539,6 @@ private void doHandshake(boolean force) throws IOException, EofException { if (!handshakeStarted) { engine.beginHandshake(); LOGGER.trace("Called engine.beginHandshake()"); - handshake(Optional.empty(), Optional.empty()); // Some engines that do not support renegotiations may be sensitive to calling // SSLEngine.beginHandshake() more than once. This guard prevents that. @@ -547,6 +546,8 @@ private void doHandshake(boolean force) throws IOException, EofException { handshakeStarted = true; } + handshake(Optional.empty(), Optional.empty()); + handshakeCompleted = true; // call client code From 745ca961b63723561ea84e91308d1963ccdaa72a Mon Sep 17 00:00:00 2001 From: "slav.babanin" Date: Thu, 5 Jun 2025 22:58:54 -0700 Subject: [PATCH 3/5] Add integration test. --- .../TlsChannelStreamFunctionalTest.java | 48 +++++++++++++++++++ 1 file changed, 48 insertions(+) diff --git a/driver-core/src/test/functional/com/mongodb/internal/connection/TlsChannelStreamFunctionalTest.java b/driver-core/src/test/functional/com/mongodb/internal/connection/TlsChannelStreamFunctionalTest.java index 3f80fcddfa3..a3ef129dc41 100644 --- a/driver-core/src/test/functional/com/mongodb/internal/connection/TlsChannelStreamFunctionalTest.java +++ b/driver-core/src/test/functional/com/mongodb/internal/connection/TlsChannelStreamFunctionalTest.java @@ -16,12 +16,17 @@ package com.mongodb.internal.connection; +import com.mongodb.ClusterFixture; import com.mongodb.MongoSocketOpenException; import com.mongodb.ServerAddress; import com.mongodb.connection.SocketSettings; import com.mongodb.connection.SslSettings; import com.mongodb.internal.TimeoutContext; import com.mongodb.internal.TimeoutSettings; +import org.bson.ByteBuf; +import org.bson.ByteBufNIO; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.Test; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.ValueSource; import org.mockito.MockedStatic; @@ -29,23 +34,36 @@ import org.mockito.invocation.InvocationOnMock; import org.mockito.stubbing.Answer; +import javax.net.ssl.SSLContext; +import javax.net.ssl.SSLEngine; import java.io.IOException; import java.net.ServerSocket; +import java.nio.ByteBuffer; import java.nio.channels.InterruptedByTimeoutException; import java.nio.channels.SocketChannel; +import java.security.NoSuchAlgorithmException; +import java.util.Arrays; +import java.util.Collections; import java.util.concurrent.TimeUnit; +import static com.mongodb.ClusterFixture.getPrimaryServerDescription; +import static com.mongodb.ClusterFixture.sleep; import static com.mongodb.internal.connection.OperationContext.simpleOperationContext; import static java.lang.String.format; import static java.util.concurrent.TimeUnit.MILLISECONDS; +import static org.junit.Assume.assumeTrue; import static org.junit.jupiter.api.Assertions.assertFalse; import static org.junit.jupiter.api.Assertions.assertInstanceOf; import static org.junit.jupiter.api.Assertions.assertNotNull; import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertTrue; import static org.junit.jupiter.api.Assertions.fail; +import static org.mockito.ArgumentMatchers.anyInt; +import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.Mockito.atLeast; +import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; class TlsChannelStreamFunctionalTest { private static final SslSettings SSL_SETTINGS = SslSettings.builder().enabled(true).build(); @@ -98,6 +116,7 @@ void shouldEstablishConnection(final int connectTimeoutMs) throws IOException, I try (StreamFactoryFactory streamFactoryFactory = new TlsChannelStreamFactoryFactory(new DefaultInetAddressResolver()); MockedStatic socketChannelMockedStatic = Mockito.mockStatic(SocketChannel.class); ServerSocket serverSocket = new ServerSocket(0, 1)) { + SingleResultSpyCaptor singleResultSpyCaptor = new SingleResultSpyCaptor<>(); socketChannelMockedStatic.when(SocketChannel::open).thenAnswer(singleResultSpyCaptor); @@ -147,4 +166,33 @@ public T answer(final InvocationOnMock invocationOnMock) throws Throwable { private static OperationContext createOperationContext(final int connectTimeoutMs) { return simpleOperationContext(new TimeoutContext(TimeoutSettings.DEFAULT.withConnectTimeoutMS(connectTimeoutMs))); } + + @Test + @DisplayName("should not call beginHandshake more than once during TLS session establishment") + void shouldNotCallBeginHandshakeMoreThenOnceDuringTlsSessionEstablishment() throws IOException, NoSuchAlgorithmException { + assumeTrue(ClusterFixture.getSslSettings().isEnabled()); + + //given + try (StreamFactoryFactory streamFactoryFactory = new TlsChannelStreamFactoryFactory(new DefaultInetAddressResolver())) { + + SSLContext sslContext = Mockito.spy(SSLContext.getDefault()); + SingleResultSpyCaptor singleResultSpyCaptor = new SingleResultSpyCaptor<>(); + when(sslContext.createSSLEngine(anyString(), anyInt())).thenAnswer(singleResultSpyCaptor); + + StreamFactory streamFactory = streamFactoryFactory.create( + SocketSettings.builder().build(), + SslSettings.builder(ClusterFixture.getSslSettings()) + .context(sslContext) + .build()); + + Stream stream = streamFactory.create(getPrimaryServerDescription().getAddress()); + + stream.open(ClusterFixture.OPERATION_CONTEXT); + ByteBuf wrap = new ByteBufNIO(ByteBuffer.wrap(new byte[]{1, 3, 4})); + stream.write(Collections.singletonList(wrap), ClusterFixture.OPERATION_CONTEXT); + + sleep(1000); + verify(singleResultSpyCaptor.getResult(), times(1)).beginHandshake(); + } + } } From a7b8855caad24730a94f60b245693dfac0a46d2e Mon Sep 17 00:00:00 2001 From: "slav.babanin" Date: Thu, 5 Jun 2025 23:01:31 -0700 Subject: [PATCH 4/5] Add comments. --- .../TlsChannelStreamFunctionalTest.java | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/driver-core/src/test/functional/com/mongodb/internal/connection/TlsChannelStreamFunctionalTest.java b/driver-core/src/test/functional/com/mongodb/internal/connection/TlsChannelStreamFunctionalTest.java index a3ef129dc41..ec0c2df5030 100644 --- a/driver-core/src/test/functional/com/mongodb/internal/connection/TlsChannelStreamFunctionalTest.java +++ b/driver-core/src/test/functional/com/mongodb/internal/connection/TlsChannelStreamFunctionalTest.java @@ -25,6 +25,7 @@ import com.mongodb.internal.TimeoutSettings; import org.bson.ByteBuf; import org.bson.ByteBufNIO; +import org.junit.jupiter.api.Assumptions; import org.junit.jupiter.api.DisplayName; import org.junit.jupiter.api.Test; import org.junit.jupiter.params.ParameterizedTest; @@ -41,23 +42,21 @@ import java.nio.ByteBuffer; import java.nio.channels.InterruptedByTimeoutException; import java.nio.channels.SocketChannel; -import java.security.NoSuchAlgorithmException; -import java.util.Arrays; import java.util.Collections; import java.util.concurrent.TimeUnit; import static com.mongodb.ClusterFixture.getPrimaryServerDescription; -import static com.mongodb.ClusterFixture.sleep; import static com.mongodb.internal.connection.OperationContext.simpleOperationContext; import static java.lang.String.format; import static java.util.concurrent.TimeUnit.MILLISECONDS; -import static org.junit.Assume.assumeTrue; +import static java.util.concurrent.TimeUnit.SECONDS; import static org.junit.jupiter.api.Assertions.assertFalse; import static org.junit.jupiter.api.Assertions.assertInstanceOf; import static org.junit.jupiter.api.Assertions.assertNotNull; import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertTrue; import static org.junit.jupiter.api.Assertions.fail; +import static org.junit.jupiter.api.Assumptions.assumeTrue; import static org.mockito.ArgumentMatchers.anyInt; import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.Mockito.atLeast; @@ -169,8 +168,8 @@ private static OperationContext createOperationContext(final int connectTimeoutM @Test @DisplayName("should not call beginHandshake more than once during TLS session establishment") - void shouldNotCallBeginHandshakeMoreThenOnceDuringTlsSessionEstablishment() throws IOException, NoSuchAlgorithmException { - assumeTrue(ClusterFixture.getSslSettings().isEnabled()); + void shouldNotCallBeginHandshakeMoreThenOnceDuringTlsSessionEstablishment() throws Exception { + assumeTrue(ClusterFixture.getSslSettings().isEnabled()); //given try (StreamFactoryFactory streamFactoryFactory = new TlsChannelStreamFactoryFactory(new DefaultInetAddressResolver())) { @@ -186,12 +185,14 @@ void shouldNotCallBeginHandshakeMoreThenOnceDuringTlsSessionEstablishment() thro .build()); Stream stream = streamFactory.create(getPrimaryServerDescription().getAddress()); - stream.open(ClusterFixture.OPERATION_CONTEXT); ByteBuf wrap = new ByteBufNIO(ByteBuffer.wrap(new byte[]{1, 3, 4})); + + //when stream.write(Collections.singletonList(wrap), ClusterFixture.OPERATION_CONTEXT); - sleep(1000); + //then + SECONDS.sleep(5); verify(singleResultSpyCaptor.getResult(), times(1)).beginHandshake(); } } From 92188dd055711b4609f4c1aca18a4ef03659a585 Mon Sep 17 00:00:00 2001 From: "slav.babanin" Date: Thu, 5 Jun 2025 23:02:36 -0700 Subject: [PATCH 5/5] Remove unused import. --- .../internal/connection/TlsChannelStreamFunctionalTest.java | 1 - 1 file changed, 1 deletion(-) diff --git a/driver-core/src/test/functional/com/mongodb/internal/connection/TlsChannelStreamFunctionalTest.java b/driver-core/src/test/functional/com/mongodb/internal/connection/TlsChannelStreamFunctionalTest.java index ec0c2df5030..3af1eaa33e1 100644 --- a/driver-core/src/test/functional/com/mongodb/internal/connection/TlsChannelStreamFunctionalTest.java +++ b/driver-core/src/test/functional/com/mongodb/internal/connection/TlsChannelStreamFunctionalTest.java @@ -25,7 +25,6 @@ import com.mongodb.internal.TimeoutSettings; import org.bson.ByteBuf; import org.bson.ByteBufNIO; -import org.junit.jupiter.api.Assumptions; import org.junit.jupiter.api.DisplayName; import org.junit.jupiter.api.Test; import org.junit.jupiter.params.ParameterizedTest;