|
16 | 16 |
|
17 | 17 | package com.mongodb.internal.connection;
|
18 | 18 |
|
| 19 | +import com.mongodb.ClusterFixture; |
19 | 20 | import com.mongodb.MongoSocketOpenException;
|
20 | 21 | import com.mongodb.ServerAddress;
|
21 | 22 | import com.mongodb.connection.SocketSettings;
|
22 | 23 | import com.mongodb.connection.SslSettings;
|
23 | 24 | import com.mongodb.internal.TimeoutContext;
|
24 | 25 | import com.mongodb.internal.TimeoutSettings;
|
| 26 | +import org.bson.ByteBuf; |
| 27 | +import org.bson.ByteBufNIO; |
| 28 | +import org.junit.jupiter.api.DisplayName; |
| 29 | +import org.junit.jupiter.api.Test; |
25 | 30 | import org.junit.jupiter.params.ParameterizedTest;
|
26 | 31 | import org.junit.jupiter.params.provider.ValueSource;
|
27 | 32 | import org.mockito.MockedStatic;
|
28 | 33 | import org.mockito.Mockito;
|
29 | 34 | import org.mockito.invocation.InvocationOnMock;
|
30 | 35 | import org.mockito.stubbing.Answer;
|
31 | 36 |
|
| 37 | +import javax.net.ssl.SSLContext; |
| 38 | +import javax.net.ssl.SSLEngine; |
32 | 39 | import java.io.IOException;
|
33 | 40 | import java.net.ServerSocket;
|
| 41 | +import java.nio.ByteBuffer; |
34 | 42 | import java.nio.channels.InterruptedByTimeoutException;
|
35 | 43 | import java.nio.channels.SocketChannel;
|
| 44 | +import java.util.Collections; |
36 | 45 | import java.util.concurrent.TimeUnit;
|
37 | 46 |
|
| 47 | +import static com.mongodb.ClusterFixture.getPrimaryServerDescription; |
38 | 48 | import static com.mongodb.internal.connection.OperationContext.simpleOperationContext;
|
39 | 49 | import static java.lang.String.format;
|
40 | 50 | import static java.util.concurrent.TimeUnit.MILLISECONDS;
|
| 51 | +import static java.util.concurrent.TimeUnit.SECONDS; |
41 | 52 | import static org.junit.jupiter.api.Assertions.assertFalse;
|
42 | 53 | import static org.junit.jupiter.api.Assertions.assertInstanceOf;
|
43 | 54 | import static org.junit.jupiter.api.Assertions.assertNotNull;
|
44 | 55 | import static org.junit.jupiter.api.Assertions.assertThrows;
|
45 | 56 | import static org.junit.jupiter.api.Assertions.assertTrue;
|
46 | 57 | import static org.junit.jupiter.api.Assertions.fail;
|
| 58 | +import static org.junit.jupiter.api.Assumptions.assumeTrue; |
| 59 | +import static org.mockito.ArgumentMatchers.anyInt; |
| 60 | +import static org.mockito.ArgumentMatchers.anyString; |
47 | 61 | import static org.mockito.Mockito.atLeast;
|
| 62 | +import static org.mockito.Mockito.times; |
48 | 63 | import static org.mockito.Mockito.verify;
|
| 64 | +import static org.mockito.Mockito.when; |
49 | 65 |
|
50 | 66 | class TlsChannelStreamFunctionalTest {
|
51 | 67 | private static final SslSettings SSL_SETTINGS = SslSettings.builder().enabled(true).build();
|
@@ -98,6 +114,7 @@ void shouldEstablishConnection(final int connectTimeoutMs) throws IOException, I
|
98 | 114 | try (StreamFactoryFactory streamFactoryFactory = new TlsChannelStreamFactoryFactory(new DefaultInetAddressResolver());
|
99 | 115 | MockedStatic<SocketChannel> socketChannelMockedStatic = Mockito.mockStatic(SocketChannel.class);
|
100 | 116 | ServerSocket serverSocket = new ServerSocket(0, 1)) {
|
| 117 | + |
101 | 118 | SingleResultSpyCaptor<SocketChannel> singleResultSpyCaptor = new SingleResultSpyCaptor<>();
|
102 | 119 | socketChannelMockedStatic.when(SocketChannel::open).thenAnswer(singleResultSpyCaptor);
|
103 | 120 |
|
@@ -147,4 +164,35 @@ public T answer(final InvocationOnMock invocationOnMock) throws Throwable {
|
147 | 164 | private static OperationContext createOperationContext(final int connectTimeoutMs) {
|
148 | 165 | return simpleOperationContext(new TimeoutContext(TimeoutSettings.DEFAULT.withConnectTimeoutMS(connectTimeoutMs)));
|
149 | 166 | }
|
| 167 | + |
| 168 | + @Test |
| 169 | + @DisplayName("should not call beginHandshake more than once during TLS session establishment") |
| 170 | + void shouldNotCallBeginHandshakeMoreThenOnceDuringTlsSessionEstablishment() throws Exception { |
| 171 | + assumeTrue(ClusterFixture.getSslSettings().isEnabled()); |
| 172 | + |
| 173 | + //given |
| 174 | + try (StreamFactoryFactory streamFactoryFactory = new TlsChannelStreamFactoryFactory(new DefaultInetAddressResolver())) { |
| 175 | + |
| 176 | + SSLContext sslContext = Mockito.spy(SSLContext.getDefault()); |
| 177 | + SingleResultSpyCaptor<SSLEngine> singleResultSpyCaptor = new SingleResultSpyCaptor<>(); |
| 178 | + when(sslContext.createSSLEngine(anyString(), anyInt())).thenAnswer(singleResultSpyCaptor); |
| 179 | + |
| 180 | + StreamFactory streamFactory = streamFactoryFactory.create( |
| 181 | + SocketSettings.builder().build(), |
| 182 | + SslSettings.builder(ClusterFixture.getSslSettings()) |
| 183 | + .context(sslContext) |
| 184 | + .build()); |
| 185 | + |
| 186 | + Stream stream = streamFactory.create(getPrimaryServerDescription().getAddress()); |
| 187 | + stream.open(ClusterFixture.OPERATION_CONTEXT); |
| 188 | + ByteBuf wrap = new ByteBufNIO(ByteBuffer.wrap(new byte[]{1, 3, 4})); |
| 189 | + |
| 190 | + //when |
| 191 | + stream.write(Collections.singletonList(wrap), ClusterFixture.OPERATION_CONTEXT); |
| 192 | + |
| 193 | + //then |
| 194 | + SECONDS.sleep(5); |
| 195 | + verify(singleResultSpyCaptor.getResult(), times(1)).beginHandshake(); |
| 196 | + } |
| 197 | + } |
150 | 198 | }
|
0 commit comments