|
17 | 17 |
|
18 | 18 | import io.servicetalk.client.api.ConnectionFactory;
|
19 | 19 | import io.servicetalk.client.api.ConnectionLimitReachedException;
|
| 20 | +import io.servicetalk.client.api.DelegatingConnectionFactory; |
20 | 21 | import io.servicetalk.client.api.LoadBalancedConnection;
|
21 | 22 | import io.servicetalk.concurrent.api.AsyncContext;
|
22 | 23 | import io.servicetalk.concurrent.api.Completable;
|
23 | 24 | import io.servicetalk.concurrent.api.ListenableAsyncCloseable;
|
24 | 25 | import io.servicetalk.concurrent.api.Single;
|
| 26 | +import io.servicetalk.concurrent.api.TerminalSignalConsumer; |
25 | 27 | import io.servicetalk.concurrent.internal.DefaultContextMap;
|
26 | 28 | import io.servicetalk.concurrent.internal.DelayedCancellable;
|
27 | 29 | import io.servicetalk.context.api.ContextMap;
|
28 | 30 | import io.servicetalk.loadbalancer.LoadBalancerObserver.HostObserver;
|
| 31 | +import io.servicetalk.transport.api.TransportObserver; |
29 | 32 |
|
30 | 33 | import org.slf4j.Logger;
|
31 | 34 | import org.slf4j.LoggerFactory;
|
@@ -112,10 +115,12 @@ private enum State {
|
112 | 115 | this.lbDescription = requireNonNull(lbDescription, "lbDescription");
|
113 | 116 | this.address = requireNonNull(address, "address");
|
114 | 117 | this.linearSearchSpace = linearSearchSpace;
|
115 |
| - this.connectionFactory = requireNonNull(connectionFactory, "connectionFactory"); |
| 118 | + this.healthIndicator = healthIndicator; |
| 119 | + requireNonNull(connectionFactory, "connectionFactory"); |
| 120 | + this.connectionFactory = healthIndicator == null ? connectionFactory : |
| 121 | + new InstrumentedConnectionFactory<>(connectionFactory, healthIndicator); |
116 | 122 | this.healthCheckConfig = healthCheckConfig;
|
117 | 123 | this.hostObserver = requireNonNull(hostObserver, "hostObserver");
|
118 |
| - this.healthIndicator = healthIndicator; |
119 | 124 | this.closeable = toAsyncCloseable(this::doClose);
|
120 | 125 | hostObserver.onHostCreated(address);
|
121 | 126 | }
|
@@ -235,7 +240,7 @@ public Single<C> newConnection(
|
235 | 240 | Single<? extends C> establishConnection = connectionFactory.newConnection(address, actualContext, null);
|
236 | 241 | if (healthCheckConfig != null) {
|
237 | 242 | // Schedule health check before returning
|
238 |
| - establishConnection = establishConnection.beforeOnError(this::markUnhealthy); |
| 243 | + establishConnection = establishConnection.beforeOnError(this::onConnectionError); |
239 | 244 | }
|
240 | 245 | return establishConnection
|
241 | 246 | .flatMap(newCnx -> {
|
@@ -302,7 +307,7 @@ private void markHealthy(final HealthCheck originalHealthCheckState) {
|
302 | 307 | }
|
303 | 308 | }
|
304 | 309 |
|
305 |
| - private void markUnhealthy(final Throwable cause) { |
| 310 | + private void onConnectionError(Throwable cause) { |
306 | 311 | assert healthCheckConfig != null;
|
307 | 312 | for (;;) {
|
308 | 313 | ConnState previous = connStateUpdater.get(this);
|
@@ -646,4 +651,56 @@ public String toString() {
|
646 | 651 | '}';
|
647 | 652 | }
|
648 | 653 | }
|
| 654 | + |
| 655 | + private static final class InstrumentedConnectionFactory<Addr, C extends LoadBalancedConnection> |
| 656 | + extends DelegatingConnectionFactory<Addr, C> { |
| 657 | + |
| 658 | + private final ConnectTracker connectTracker; |
| 659 | + |
| 660 | + InstrumentedConnectionFactory(final ConnectionFactory<Addr, C> delegate, ConnectTracker connectTracker) { |
| 661 | + super(delegate); |
| 662 | + this.connectTracker = connectTracker; |
| 663 | + } |
| 664 | + |
| 665 | + @Override |
| 666 | + public Single<C> newConnection(Addr addr, @Nullable ContextMap context, @Nullable TransportObserver observer) { |
| 667 | + return Single.defer(() -> { |
| 668 | + final long connectStartTime = connectTracker.beforeConnectStart(); |
| 669 | + return delegate().newConnection(addr, context, observer) |
| 670 | + .beforeFinally(new ConnectSignalConsumer<>(connectStartTime, connectTracker)) |
| 671 | + .shareContextOnSubscribe(); |
| 672 | + }); |
| 673 | + } |
| 674 | + } |
| 675 | + |
| 676 | + private static class ConnectSignalConsumer<C extends LoadBalancedConnection> implements TerminalSignalConsumer { |
| 677 | + |
| 678 | + private final ConnectTracker connectTracker; |
| 679 | + private final long connectStartTime; |
| 680 | + |
| 681 | + ConnectSignalConsumer(final long connectStartTime, final ConnectTracker connectTracker) { |
| 682 | + this.connectStartTime = connectStartTime; |
| 683 | + this.connectTracker = connectTracker; |
| 684 | + } |
| 685 | + |
| 686 | + @Override |
| 687 | + public void onComplete() { |
| 688 | + connectTracker.onConnectSuccess(connectStartTime); |
| 689 | + } |
| 690 | + |
| 691 | + @Override |
| 692 | + public void cancel() { |
| 693 | + // We assume cancellation is the result of some sort of timeout. |
| 694 | + doOnError(); |
| 695 | + } |
| 696 | + |
| 697 | + @Override |
| 698 | + public void onError(Throwable t) { |
| 699 | + doOnError(); |
| 700 | + } |
| 701 | + |
| 702 | + private void doOnError() { |
| 703 | + connectTracker.onConnectError(connectStartTime); |
| 704 | + } |
| 705 | + } |
649 | 706 | }
|
0 commit comments