Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

loadbalancer: Add ConnectTracker and make HealthIndicator extend it #2818

Merged
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
/*
* Copyright © 2024 Apple Inc. and the ServiceTalk project authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.servicetalk.loadbalancer;

/**
* An interface for tracking connection establishment measurements.
* <p>
* This has an intended usage similar to the {@link RequestTracker} but with a focus on connection establishment
* metrics. Callers are expected to invoke {@link #beforeConnectStart()} before attempting to connect, retain the
* returned timestamp, and hand it to either {@link #onConnectSuccess(long)} or {@link #onConnectError(long)} once
* the outcome of the attempt is known.
*/
interface ConnectTracker {

/**
* Get the current time in nanoseconds.
* <p>
* Note: this must not be a stateful API. E.g., it does not necessarily have a correlation with any other method
* call and, as such, it shouldn't be used as a method of counting in the same way that {@link RequestTracker} is
* used.
*
* @return the current time in nanoseconds.
*/
long beforeConnectStart();

/**
* Callback to notify the parent {@link HealthChecker} that an attempt to connect to this host has succeeded.
*
* @param beforeConnectStart the time that the connection attempt was initiated, as returned by
* {@link #beforeConnectStart()}.
*/
void onConnectSuccess(long beforeConnectStart);

/**
* Callback to notify the parent {@link HealthChecker} that an attempt to connect to this host has failed.
*
* @param beforeConnectStart the time that the connection attempt was initiated, as returned by
* {@link #beforeConnectStart()}.
*/
void onConnectError(long beforeConnectStart);
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,18 @@

import io.servicetalk.client.api.ConnectionFactory;
import io.servicetalk.client.api.ConnectionLimitReachedException;
import io.servicetalk.client.api.DelegatingConnectionFactory;
import io.servicetalk.client.api.LoadBalancedConnection;
import io.servicetalk.concurrent.api.AsyncContext;
import io.servicetalk.concurrent.api.Completable;
import io.servicetalk.concurrent.api.ListenableAsyncCloseable;
import io.servicetalk.concurrent.api.Single;
import io.servicetalk.concurrent.api.TerminalSignalConsumer;
import io.servicetalk.concurrent.internal.DefaultContextMap;
import io.servicetalk.concurrent.internal.DelayedCancellable;
import io.servicetalk.context.api.ContextMap;
import io.servicetalk.loadbalancer.LoadBalancerObserver.HostObserver;
import io.servicetalk.transport.api.TransportObserver;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
Expand Down Expand Up @@ -112,10 +115,12 @@ private enum State {
this.lbDescription = requireNonNull(lbDescription, "lbDescription");
this.address = requireNonNull(address, "address");
this.linearSearchSpace = linearSearchSpace;
this.connectionFactory = requireNonNull(connectionFactory, "connectionFactory");
this.healthIndicator = healthIndicator;
requireNonNull(connectionFactory, "connectionFactory");
this.connectionFactory = healthIndicator == null ? connectionFactory :
new InstrumentedConnectionFactory<>(connectionFactory, healthIndicator);
this.healthCheckConfig = healthCheckConfig;
this.hostObserver = requireNonNull(hostObserver, "hostObserver");
this.healthIndicator = healthIndicator;
this.closeable = toAsyncCloseable(this::doClose);
hostObserver.onHostCreated(address);
}
Expand Down Expand Up @@ -235,7 +240,7 @@ public Single<C> newConnection(
Single<? extends C> establishConnection = connectionFactory.newConnection(address, actualContext, null);
if (healthCheckConfig != null) {
// Schedule health check before returning
establishConnection = establishConnection.beforeOnError(this::markUnhealthy);
establishConnection = establishConnection.beforeOnError(this::onConnectionError);
}
return establishConnection
.flatMap(newCnx -> {
Expand Down Expand Up @@ -302,7 +307,7 @@ private void markHealthy(final HealthCheck originalHealthCheckState) {
}
}

private void markUnhealthy(final Throwable cause) {
private void onConnectionError(Throwable cause) {
assert healthCheckConfig != null;
for (;;) {
ConnState previous = connStateUpdater.get(this);
Expand Down Expand Up @@ -646,4 +651,55 @@ public String toString() {
'}';
}
}

/**
 * A {@link ConnectionFactory} decorator that reports the outcome of every connection attempt made through the
 * delegate to a {@link ConnectTracker}.
 *
 * @param <Addr> the type of address the delegate connects to.
 * @param <C> the type of connection the delegate produces.
 */
private static final class InstrumentedConnectionFactory<Addr, C extends LoadBalancedConnection>
        extends DelegatingConnectionFactory<Addr, C> {

    private final ConnectTracker connectTracker;

    InstrumentedConnectionFactory(final ConnectionFactory<Addr, C> delegate, ConnectTracker connectTracker) {
        super(delegate);
        this.connectTracker = connectTracker;
    }

    @Override
    public Single<C> newConnection(Addr addr, @Nullable ContextMap context, @Nullable TransportObserver observer) {
        // Deferring means the start timestamp is taken when the returned Single is subscribed to, not when this
        // method is invoked, and that each subscription gets its own timestamp and signal consumer.
        return Single.defer(() -> {
            final ConnectSignalConsumer<C> outcomeReporter =
                    new ConnectSignalConsumer<>(connectTracker.beforeConnectStart(), connectTracker);
            final Single<C> connectAttempt = delegate().newConnection(addr, context, observer);
            return connectAttempt.beforeFinally(outcomeReporter);
        });
    }
}

/**
 * Translates the terminal signal of a single connection attempt into a {@link ConnectTracker} callback:
 * completion is reported as a success, while both errors and cancellation are reported as connect errors.
 * The start timestamp supplied at construction is passed back to the tracker unchanged.
 *
 * @param <C> the type of connection being established.
 */
private static class ConnectSignalConsumer<C extends LoadBalancedConnection> implements TerminalSignalConsumer {

    private final long startTimeNanos;
    private final ConnectTracker tracker;

    ConnectSignalConsumer(final long connectStartTime, final ConnectTracker connectTracker) {
        this.tracker = connectTracker;
        this.startTimeNanos = connectStartTime;
    }

    @Override
    public void onComplete() {
        tracker.onConnectSuccess(startTimeNanos);
    }

    @Override
    public void onError(Throwable t) {
        tracker.onConnectError(startTimeNanos);
    }

    @Override
    public void cancel() {
        // We assume cancellation is the result of some sort of timeout, so it is reported as a failure.
        tracker.onConnectError(startTimeNanos);
    }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -81,19 +81,19 @@ abstract class DefaultRequestTracker implements RequestTracker, ScoreSupplier {
protected abstract long currentTimeNanos();

@Override
public final long beforeStart() {
public final long beforeRequestStart() {
pendingUpdater.incrementAndGet(this);
return currentTimeNanos();
}

@Override
public void onSuccess(final long startTimeNanos) {
public void onRequestSuccess(final long startTimeNanos) {
pendingUpdater.decrementAndGet(this);
calculateAndStore((ewma, currentLatency) -> currentLatency, startTimeNanos);
}

@Override
public void onError(final long startTimeNanos, ErrorClass errorClass) {
public void onRequestError(final long startTimeNanos, ErrorClass errorClass) {
pendingUpdater.decrementAndGet(this);
calculateAndStore(errorClass == ErrorClass.CANCELLED ? this:: cancelPenalty : this::errorPenalty,
startTimeNanos);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,6 @@ public enum ErrorClass {
* Failures related to locally enforced timeouts that prevent session establishment with the peer.
*/
LOCAL_ORIGIN_TIMEOUT(true),
/**
* Failures related to connection establishment.
*/
LOCAL_ORIGIN_CONNECT_FAILED(true),

/**
* Failures caused locally, these would be things that failed due to an exception locally.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
* health check system can give the host information about its perceived health and the host can give the
* health check system information about request results.
*/
interface HealthIndicator extends RequestTracker, ScoreSupplier, Cancellable {
interface HealthIndicator extends RequestTracker, ConnectTracker, ScoreSupplier, Cancellable {

/**
* Whether the host is considered healthy by the HealthIndicator.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,13 @@
* A tracker of latency of an action over time.
* <p>
* The usage of the RequestTracker is intended to follow the simple workflow:
* - At initiation of an action for which a request is must call {@link RequestTracker#beforeStart()} and save the
* timestamp much like would be done when using a stamped lock.
* - Once the request event is complete only one of the {@link RequestTracker#onSuccess(long)} or
* {@link RequestTracker#onError(long, ErrorClass)} methods must be called and called exactly once.
* In other words, every call to {@link RequestTracker#beforeStart()} must be followed by exactly one call to either of
* the completion methods {@link RequestTracker#onSuccess(long)} or
* {@link RequestTracker#onError(long, ErrorClass)}. Failure to do so can cause state corruption in the
* - At the initiation of an action for which latency is to be tracked, the caller must call
* {@link RequestTracker#beforeRequestStart()} and save the timestamp, much as one would with a stamped lock.
* - Once the request event is complete only one of the {@link RequestTracker#onRequestSuccess(long)} or
* {@link RequestTracker#onRequestError(long, ErrorClass)} methods must be called and called exactly once.
* In other words, every call to {@link RequestTracker#beforeRequestStart()} must be followed by exactly one call to
* either of the completion methods {@link RequestTracker#onRequestSuccess(long)} or
* {@link RequestTracker#onRequestError(long, ErrorClass)}. Failure to do so can cause state corruption in the
* {@link RequestTracker} implementations which may track not just latency but also the outstanding requests.
*/
public interface RequestTracker {
Expand All @@ -40,20 +40,20 @@ public interface RequestTracker {
*
* @return Current time in nanoseconds.
*/
long beforeStart();
long beforeRequestStart();

/**
* Records a successful completion of the action for which latency is to be tracked.
*
* @param beforeStartTimeNs return value from {@link #beforeStart()}.
* @param beforeStartTimeNs return value from {@link #beforeRequestStart()}.
*/
void onSuccess(long beforeStartTimeNs);
void onRequestSuccess(long beforeStartTimeNs);

/**
* Records a failed completion of the action for which latency is to be tracked.
*
* @param beforeStartTimeNs return value from {@link #beforeStart()}.
* @param beforeStartTimeNs return value from {@link #beforeRequestStart()}.
* @param errorClass the class of error that triggered this method.
*/
void onError(long beforeStartTimeNs, ErrorClass errorClass);
void onRequestError(long beforeStartTimeNs, ErrorClass errorClass);
}
Original file line number Diff line number Diff line change
Expand Up @@ -115,20 +115,41 @@ public final boolean isHealthy() {
}

@Override
public final void onSuccess(final long beforeStartTimeNs) {
super.onSuccess(beforeStartTimeNs);
public final void onRequestSuccess(final long beforeStartTimeNs) {
super.onRequestSuccess(beforeStartTimeNs);
successes.incrementAndGet();
consecutive5xx.set(0);
LOGGER.trace("Observed success for address {}", address);
}

@Override
public final void onError(final long beforeStartTimeNs, ErrorClass errorClass) {
super.onError(beforeStartTimeNs, errorClass);
public final void onRequestError(final long beforeStartTimeNs, ErrorClass errorClass) {
super.onRequestError(beforeStartTimeNs, errorClass);
// For now, don't consider cancellation to be an error or a success.
if (errorClass == ErrorClass.CANCELLED) {
return;
if (errorClass != ErrorClass.CANCELLED) {
doOnError();
}
}

/**
* Provides the start timestamp for a connection attempt.
* This implementation only reads {@code currentTimeNanos()}; it updates no counters itself.
*
* @return the value of {@code currentTimeNanos()}.
*/
@Override
public long beforeConnectStart() {
return currentTimeNanos();
}

/**
* Records a failed connection attempt by delegating to {@code doOnError()}, the same failure accounting that
* {@code onRequestError} applies to non-cancelled request errors.
*
* @param beforeConnectStart the start timestamp of the connection attempt; not used by this implementation.
*/
@Override
public void onConnectError(long beforeConnectStart) {
// This assumes that the connect request was intended to be used for a request dispatch which
// will have now failed. This is not strictly true: a connection can be acquired and simply not
// used, but in practice it's a very good assumption.
doOnError();
}

/**
* Intentionally does nothing: a successful connect is neither counted as a success nor used to reset the
* consecutive failure counter here; that bookkeeping is done in {@code onRequestSuccess}.
*
* @param beforeConnectStart the start timestamp of the connection attempt; not used by this implementation.
*/
@Override
public void onConnectSuccess(long beforeConnectStart) {
// noop: the request path will now determine if the request was a success or failure.
}

private void doOnError() {
failures.incrementAndGet();
final int consecutiveFailures = consecutive5xx.incrementAndGet();
final OutlierDetectorConfig localConfig = currentConfig();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,14 @@
import java.util.function.Predicate;
import javax.annotation.Nullable;

import static io.servicetalk.concurrent.api.Single.failed;
import static io.servicetalk.concurrent.api.Single.succeeded;
import static io.servicetalk.concurrent.internal.DeliberateException.DELIBERATE_EXCEPTION;
import static io.servicetalk.loadbalancer.HealthCheckConfig.DEFAULT_HEALTH_CHECK_FAILED_CONNECTIONS_THRESHOLD;
import static io.servicetalk.loadbalancer.UnhealthyHostConnectionFactory.UNHEALTHY_HOST_EXCEPTION;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.is;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;
Expand Down Expand Up @@ -233,4 +236,17 @@ void forwardsHealthIndicatorScore() {
assertThat(host.score(), is(10));
verify(healthIndicator, times(1)).score();
}

@Test
void connectFailuresAreForwardedToHealthIndicator() {
    // Arrange: every connection attempt made through this factory fails immediately.
    connectionFactory = new TestConnectionFactory(address -> failed(DELIBERATE_EXCEPTION));
    final HealthIndicator indicator = mock(HealthIndicator.class);
    buildHost(indicator);
    verify(mockHostObserver, times(1)).onHostCreated("address");

    // Act: the failure surfaces to the caller wrapped in an ExecutionException.
    final ExecutionException thrown = assertThrows(ExecutionException.class,
            () -> host.newConnection(cxn -> true, false, null).toFuture().get());
    assertEquals(DELIBERATE_EXCEPTION, thrown.getCause());

    // Assert: the mock returns 0 from beforeConnectStart(), and that value must be echoed back on failure.
    verify(indicator, times(1)).beforeConnectStart();
    verify(indicator, times(1)).onConnectError(0L);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -193,10 +193,23 @@ public int score() {
}

@Override
public long beforeStart() {
public long beforeRequestStart() {
return 0;
}

@Override
public long beforeConnectStart() {
// Stub: always reports a start time of 0.
return 0;
}

@Override
public void onConnectSuccess(long beforeConnectStart) {
// Intentionally empty: this implementation ignores connect successes.
}

@Override
public void onConnectError(long beforeConnectStart) {
// Intentionally empty: this implementation ignores connect failures.
}

@Override
public void cancel() {
synchronized (indicatorSet) {
Expand All @@ -210,11 +223,11 @@ public boolean isHealthy() {
}

@Override
public void onSuccess(long beforeStartTime) {
public void onRequestSuccess(long beforeStartTime) {
}

@Override
public void onError(long beforeStartTime, ErrorClass errorClass) {
public void onRequestError(long beforeStartTime, ErrorClass errorClass) {
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,15 +37,15 @@ void test() {
Assertions.assertEquals(0, requestTracker.score());

// upon success score
requestTracker.onSuccess(requestTracker.beforeStart());
requestTracker.onRequestSuccess(requestTracker.beforeRequestStart());
Assertions.assertEquals(-500, requestTracker.score());

// error penalty
requestTracker.onError(requestTracker.beforeStart(), ErrorClass.LOCAL_ORIGIN_CONNECT_FAILED);
requestTracker.onRequestError(requestTracker.beforeRequestStart(), ErrorClass.EXT_ORIGIN_REQUEST_FAILED);
Assertions.assertEquals(-5000, requestTracker.score());

// cancellation penalty
requestTracker.onError(requestTracker.beforeStart(), ErrorClass.CANCELLED);
requestTracker.onRequestError(requestTracker.beforeRequestStart(), ErrorClass.CANCELLED);
Assertions.assertEquals(-12_500, requestTracker.score());

// decay
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -150,8 +150,8 @@ private void eject(HealthIndicator indicator) {
if (!indicator.isHealthy()) {
break;
}
long startTime = indicator.beforeStart();
indicator.onError(startTime + 1, ErrorClass.EXT_ORIGIN_REQUEST_FAILED);
long startTime = indicator.beforeRequestStart();
indicator.onRequestError(startTime + 1, ErrorClass.EXT_ORIGIN_REQUEST_FAILED);
}
}
}
Loading
Loading