Skip to content

Commit a7a3df5

Browse files
authored
Merge changes from tls-channel to prevent accidentally calling SSLEngine (#1726)
- Perform handshake after marking handshake started. - Add an integration test case, as upstream didn't include one to cover this change. JAVA-5797
1 parent 2a6e24f commit a7a3df5

File tree

2 files changed

+69
-6
lines changed

2 files changed

+69
-6
lines changed

driver-core/src/main/com/mongodb/internal/connection/tlschannel/impl/TlsChannelImpl.java

Lines changed: 21 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,9 @@ public TlsChannelImpl(
159159
private final Lock readLock = new ReentrantLock();
160160
private final Lock writeLock = new ReentrantLock();
161161

162-
private volatile boolean negotiated = false;
162+
private boolean handshakeStarted = false;
163+
164+
private volatile boolean handshakeCompleted = false;
163165

164166
/**
165167
* Whether a IOException was received from the underlying channel or from the {@link SSLEngine}.
@@ -526,22 +528,35 @@ public void handshake() throws IOException {
526528
}
527529

528530
private void doHandshake(boolean force) throws IOException, EofException {
529-
if (!force && negotiated) return;
531+
if (!force && handshakeCompleted) {
532+
return;
533+
}
530534
initLock.lock();
531535
try {
532536
if (invalid || shutdownSent) throw new ClosedChannelException();
533-
if (force || !negotiated) {
534-
engine.beginHandshake();
535-
LOGGER.trace("Called engine.beginHandshake()");
537+
if (force || !handshakeCompleted) {
538+
539+
if (!handshakeStarted) {
540+
engine.beginHandshake();
541+
LOGGER.trace("Called engine.beginHandshake()");
542+
543+
// Some engines that do not support renegotiations may be sensitive to calling
544+
// SSLEngine.beginHandshake() more than once. This guard prevents that.
545+
// See: https://github.com/marianobarrios/tls-channel/issues/197
546+
handshakeStarted = true;
547+
}
548+
536549
handshake(Optional.empty(), Optional.empty());
550+
551+
handshakeCompleted = true;
552+
537553
// call client code
538554
try {
539555
initSessionCallback.accept(engine.getSession());
540556
} catch (Exception e) {
541557
LOGGER.trace("client code threw exception in session initialization callback", e);
542558
throw new TlsChannelCallbackException("session initialization callback failed", e);
543559
}
544-
negotiated = true;
545560
}
546561
} finally {
547562
initLock.unlock();

driver-core/src/test/functional/com/mongodb/internal/connection/TlsChannelStreamFunctionalTest.java

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,36 +16,52 @@
1616

1717
package com.mongodb.internal.connection;
1818

19+
import com.mongodb.ClusterFixture;
1920
import com.mongodb.MongoSocketOpenException;
2021
import com.mongodb.ServerAddress;
2122
import com.mongodb.connection.SocketSettings;
2223
import com.mongodb.connection.SslSettings;
2324
import com.mongodb.internal.TimeoutContext;
2425
import com.mongodb.internal.TimeoutSettings;
26+
import org.bson.ByteBuf;
27+
import org.bson.ByteBufNIO;
28+
import org.junit.jupiter.api.DisplayName;
29+
import org.junit.jupiter.api.Test;
2530
import org.junit.jupiter.params.ParameterizedTest;
2631
import org.junit.jupiter.params.provider.ValueSource;
2732
import org.mockito.MockedStatic;
2833
import org.mockito.Mockito;
2934
import org.mockito.invocation.InvocationOnMock;
3035
import org.mockito.stubbing.Answer;
3136

37+
import javax.net.ssl.SSLContext;
38+
import javax.net.ssl.SSLEngine;
3239
import java.io.IOException;
3340
import java.net.ServerSocket;
41+
import java.nio.ByteBuffer;
3442
import java.nio.channels.InterruptedByTimeoutException;
3543
import java.nio.channels.SocketChannel;
44+
import java.util.Collections;
3645
import java.util.concurrent.TimeUnit;
3746

47+
import static com.mongodb.ClusterFixture.getPrimaryServerDescription;
3848
import static com.mongodb.internal.connection.OperationContext.simpleOperationContext;
3949
import static java.lang.String.format;
4050
import static java.util.concurrent.TimeUnit.MILLISECONDS;
51+
import static java.util.concurrent.TimeUnit.SECONDS;
4152
import static org.junit.jupiter.api.Assertions.assertFalse;
4253
import static org.junit.jupiter.api.Assertions.assertInstanceOf;
4354
import static org.junit.jupiter.api.Assertions.assertNotNull;
4455
import static org.junit.jupiter.api.Assertions.assertThrows;
4556
import static org.junit.jupiter.api.Assertions.assertTrue;
4657
import static org.junit.jupiter.api.Assertions.fail;
58+
import static org.junit.jupiter.api.Assumptions.assumeTrue;
59+
import static org.mockito.ArgumentMatchers.anyInt;
60+
import static org.mockito.ArgumentMatchers.anyString;
4761
import static org.mockito.Mockito.atLeast;
62+
import static org.mockito.Mockito.times;
4863
import static org.mockito.Mockito.verify;
64+
import static org.mockito.Mockito.when;
4965

5066
class TlsChannelStreamFunctionalTest {
5167
private static final SslSettings SSL_SETTINGS = SslSettings.builder().enabled(true).build();
@@ -98,6 +114,7 @@ void shouldEstablishConnection(final int connectTimeoutMs) throws IOException, I
98114
try (StreamFactoryFactory streamFactoryFactory = new TlsChannelStreamFactoryFactory(new DefaultInetAddressResolver());
99115
MockedStatic<SocketChannel> socketChannelMockedStatic = Mockito.mockStatic(SocketChannel.class);
100116
ServerSocket serverSocket = new ServerSocket(0, 1)) {
117+
101118
SingleResultSpyCaptor<SocketChannel> singleResultSpyCaptor = new SingleResultSpyCaptor<>();
102119
socketChannelMockedStatic.when(SocketChannel::open).thenAnswer(singleResultSpyCaptor);
103120

@@ -147,4 +164,35 @@ public T answer(final InvocationOnMock invocationOnMock) throws Throwable {
147164
private static OperationContext createOperationContext(final int connectTimeoutMs) {
148165
return simpleOperationContext(new TimeoutContext(TimeoutSettings.DEFAULT.withConnectTimeoutMS(connectTimeoutMs)));
149166
}
167+
168+
@Test
169+
@DisplayName("should not call beginHandshake more than once during TLS session establishment")
170+
void shouldNotCallBeginHandshakeMoreThenOnceDuringTlsSessionEstablishment() throws Exception {
171+
assumeTrue(ClusterFixture.getSslSettings().isEnabled());
172+
173+
//given
174+
try (StreamFactoryFactory streamFactoryFactory = new TlsChannelStreamFactoryFactory(new DefaultInetAddressResolver())) {
175+
176+
SSLContext sslContext = Mockito.spy(SSLContext.getDefault());
177+
SingleResultSpyCaptor<SSLEngine> singleResultSpyCaptor = new SingleResultSpyCaptor<>();
178+
when(sslContext.createSSLEngine(anyString(), anyInt())).thenAnswer(singleResultSpyCaptor);
179+
180+
StreamFactory streamFactory = streamFactoryFactory.create(
181+
SocketSettings.builder().build(),
182+
SslSettings.builder(ClusterFixture.getSslSettings())
183+
.context(sslContext)
184+
.build());
185+
186+
Stream stream = streamFactory.create(getPrimaryServerDescription().getAddress());
187+
stream.open(ClusterFixture.OPERATION_CONTEXT);
188+
ByteBuf wrap = new ByteBufNIO(ByteBuffer.wrap(new byte[]{1, 3, 4}));
189+
190+
//when
191+
stream.write(Collections.singletonList(wrap), ClusterFixture.OPERATION_CONTEXT);
192+
193+
//then
194+
SECONDS.sleep(5);
195+
verify(singleResultSpyCaptor.getResult(), times(1)).beginHandshake();
196+
}
197+
}
150198
}

0 commit comments

Comments
 (0)