Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP: improve AsyncContext API to make it easier to add additional context propagation #3178

Closed
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
/*
* Copyright © 2025 Apple Inc. and the ServiceTalk project authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.servicetalk.concurrent.api;

import io.servicetalk.context.api.ContextMap;

import org.openjdk.jmh.annotations.Benchmark;
import org.openjdk.jmh.annotations.BenchmarkMode;
import org.openjdk.jmh.annotations.Fork;
import org.openjdk.jmh.annotations.Measurement;
import org.openjdk.jmh.annotations.Mode;
import org.openjdk.jmh.annotations.OutputTimeUnit;
import org.openjdk.jmh.annotations.Scope;
import org.openjdk.jmh.annotations.Setup;
import org.openjdk.jmh.annotations.State;
import org.openjdk.jmh.annotations.Warmup;

import java.util.concurrent.TimeUnit;
import java.util.function.Function;

/**
*
*/
@Fork(1)
@State(Scope.Benchmark)
@Warmup(iterations = 5, time = 3)
@Measurement(iterations = 5, time = 3)
@OutputTimeUnit(TimeUnit.NANOSECONDS)
@BenchmarkMode(Mode.AverageTime)
public class AsyncContextProviderBenchmark {

/**
* gc profiling of the DefaultAsyncContextProvider shows that the Scope based detachment can be stack allocated
* at least under some conditions.
*
* Benchmark Mode Cnt Score Error Units
* AsyncContextProviderBenchmark.contextRestoreCost avgt 5 3.932 ± 0.022 ns/op
* AsyncContextProviderBenchmark.contextRestoreCost:gc.alloc.rate avgt 5 ≈ 10⁻⁴ MB/sec
* AsyncContextProviderBenchmark.contextRestoreCost:gc.alloc.rate.norm avgt 5 ≈ 10⁻⁶ B/op
* AsyncContextProviderBenchmark.contextRestoreCost:gc.count avgt 5 ≈ 0 counts
* AsyncContextProviderBenchmark.contextSaveAndRestoreCost avgt 5 1.712 ± 0.005 ns/op
* AsyncContextProviderBenchmark.contextSaveAndRestoreCost:gc.alloc.rate avgt 5 ≈ 10⁻⁴ MB/sec
* AsyncContextProviderBenchmark.contextSaveAndRestoreCost:gc.alloc.rate.norm avgt 5 ≈ 10⁻⁷ B/op
* AsyncContextProviderBenchmark.contextSaveAndRestoreCost:gc.count avgt 5 ≈ 0 counts
*/

private static final ContextMap.Key<String> KEY = ContextMap.Key.newKey("test-key", String.class);
private static final String EXPECTED = "hello, world!";

private static Function<String, String> wrappedFunction;

@Setup
public void setup() {
// This will capture the current context
wrappedFunction = AsyncContext.wrapFunction(ignored -> AsyncContext.context().get(KEY));
AsyncContext.context().put(KEY, EXPECTED);
}

@Benchmark
public String contextRestoreCost() {
return wrappedFunction.apply("ignored");
}

@Benchmark
public String contextSaveAndRestoreCost() {
return AsyncContext.wrapFunction(Function.<String>identity()).apply("ignored");
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,9 @@
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.Collections;
import java.util.ConcurrentModificationException;
import java.util.List;
import java.util.Map;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutorService;
Expand All @@ -31,6 +33,7 @@
import java.util.function.BiPredicate;
import java.util.function.Consumer;
import java.util.function.Function;
import java.util.function.UnaryOperator;
import javax.annotation.Nullable;

import static io.servicetalk.concurrent.api.AsyncContextExecutorPlugin.EXECUTOR_PLUGIN;
Expand All @@ -51,6 +54,9 @@ public final class AsyncContext {
private static final int STATE_INIT = 0;
private static final int STATE_AUTO_ENABLED = 1;
private static final int STATE_ENABLED = 2;

private static final AsyncContextProvider DEFAULT_ENABLED_PROVIDER;

/**
* Note this mechanism is racy. Currently only the {@link #disable()} method is exposed publicly and
* {@link #STATE_DISABLED} is a terminal state. Because we favor going to the disabled state we don't have to worry
Expand All @@ -62,7 +68,17 @@ public final class AsyncContext {
* use case for this is a "once at start up" to {@link #disable()} this mechanism completely. This is currently a
* best effort mechanism for performance reasons, and we can re-evaluate later if more strict behavior is required.
*/
private static AsyncContextProvider provider = DefaultAsyncContextProvider.INSTANCE;
private static AsyncContextProvider provider;

static {
AsyncContextProvider result = DefaultAsyncContextProvider.INSTANCE;
List<UnaryOperator<AsyncContextProvider>> wrappers = asyncProviderWrappers();
for (UnaryOperator<AsyncContextProvider> wrapper : wrappers) {
result = wrapper.apply(result);
}
DEFAULT_ENABLED_PROVIDER = result;
provider = DEFAULT_ENABLED_PROVIDER;
}

private AsyncContext() {
// no instances
Expand Down Expand Up @@ -438,7 +454,7 @@ public static ScheduledExecutorService wrapJdkScheduledExecutorService(final Sch
*/
public static Runnable wrapRunnable(final Runnable runnable) {
AsyncContextProvider provider = provider();
return provider.wrapRunnable(runnable, provider.context());
return provider.wrapRunnable(runnable, provider.saveContext());
}

/**
Expand All @@ -449,7 +465,7 @@ public static Runnable wrapRunnable(final Runnable runnable) {
*/
public static <V> Callable<V> wrapCallable(final Callable<V> callable) {
AsyncContextProvider provider = provider();
return provider.wrapCallable(callable, provider.context());
return provider.wrapCallable(callable, provider.saveContext());
}

/**
Expand All @@ -460,7 +476,7 @@ public static <V> Callable<V> wrapCallable(final Callable<V> callable) {
*/
public static <T> Consumer<T> wrapConsumer(final Consumer<T> consumer) {
AsyncContextProvider provider = provider();
return provider.wrapConsumer(consumer, provider.context());
return provider.wrapConsumer(consumer, provider.saveContext());
}

/**
Expand All @@ -472,7 +488,7 @@ public static <T> Consumer<T> wrapConsumer(final Consumer<T> consumer) {
*/
public static <T, U> Function<T, U> wrapFunction(final Function<T, U> func) {
AsyncContextProvider provider = provider();
return provider.wrapFunction(func, provider.context());
return provider.wrapFunction(func, provider.saveContext());
}

/**
Expand All @@ -484,7 +500,7 @@ public static <T, U> Function<T, U> wrapFunction(final Function<T, U> func) {
*/
public static <T, U> BiConsumer<T, U> wrapBiConsume(final BiConsumer<T, U> consumer) {
AsyncContextProvider provider = provider();
return provider.wrapBiConsumer(consumer, provider.context());
return provider.wrapBiConsumer(consumer, provider.saveContext());
}

/**
Expand All @@ -497,7 +513,7 @@ public static <T, U> BiConsumer<T, U> wrapBiConsume(final BiConsumer<T, U> consu
*/
public static <T, U, V> BiFunction<T, U, V> wrapBiFunction(BiFunction<T, U, V> func) {
AsyncContextProvider provider = provider();
return provider.wrapBiFunction(func, provider.context());
return provider.wrapBiFunction(func, provider.saveContext());
}

/**
Expand Down Expand Up @@ -547,7 +563,7 @@ static void autoEnable() {
}

private static void enable0() {
provider = DefaultAsyncContextProvider.INSTANCE;
provider = DEFAULT_ENABLED_PROVIDER;
EXECUTOR_PLUGINS.add(EXECUTOR_PLUGIN);
LOGGER.debug("Enabled.");

Expand All @@ -561,4 +577,8 @@ private static void disable0() {
EXECUTOR_PLUGINS.remove(EXECUTOR_PLUGIN);
LOGGER.info("Disabled. Features that depend on AsyncContext will stop working.");
}

private static List<UnaryOperator<AsyncContextProvider>> asyncProviderWrappers() {
return Collections.emptyList();
}
}

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -40,10 +40,30 @@ interface AsyncContextProvider {
/**
* Get the current context.
*
* Note that this method is for getting the {@link ContextMap} for use by the application code. For saving the
* current state for crossing an async boundary see the {@link AsyncContextProvider#saveContext()} method.
*
* @return The current context.
*/
ContextMap context();

/**
* Save existing context in preparation for an asynchronous thread jump.
*
* Note that this can do more than just package up the ServiceTalk {@link AsyncContext} and could be enhanced or
* wrapped to bundle up additional contexts such as the OpenTelemetry or grpc contexts.
* @return the saved context state that may be restored later.
*/
ContextMap saveContext();

/**
* Restore the previously saved {@link ContextMap} to the local state.
* @param contextMap representing the state previously saved via {@link AsyncContextProvider#saveContext()} and
* that is intended to be restored.
* @return a {@link Scope} that must be closed at the end of the attachment.
*/
Scope attachContext(ContextMap contextMap);

/**
* Wrap the {@link Cancellable} to ensure it is able to track {@link AsyncContext} correctly.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1730,7 +1730,7 @@ public final Future<Void> toFuture() {
*/
ContextMap contextForSubscribe(AsyncContextProvider provider) {
// the default behavior is to copy the map. Some operators may want to use shared map
return provider.context().copy();
return provider.saveContext().copy();
}

/**
Expand Down Expand Up @@ -2271,7 +2271,9 @@ private void subscribeWithContext(Subscriber subscriber,
handleSubscribe(wrapped, contextMap, contextProvider);
} else {
// Ensure that AsyncContext used for handleSubscribe() is the contextMap for the subscribe()
contextProvider.wrapRunnable(() -> handleSubscribe(wrapped, contextMap, contextProvider), contextMap).run();
try(Scope unused = contextProvider.attachContext(contextMap)) {
handleSubscribe(wrapped, contextMap, contextProvider);
}
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ final class CompletableShareContextOnSubscribe extends AbstractNoHandleSubscribe

@Override
ContextMap contextForSubscribe(AsyncContextProvider provider) {
return provider.context();
return provider.saveContext();
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,6 @@
import java.util.List;
import java.util.concurrent.Callable;

import static io.servicetalk.concurrent.api.DefaultAsyncContextProvider.INSTANCE;

final class ContextAwareExecutorUtils {

private ContextAwareExecutorUtils() {
Expand All @@ -32,7 +30,7 @@ private ContextAwareExecutorUtils() {

static <X> Collection<? extends Callable<X>> wrap(Collection<? extends Callable<X>> tasks) {
List<Callable<X>> wrappedTasks = new ArrayList<>(tasks.size());
ContextMap contextMap = INSTANCE.context();
ContextMap contextMap = AsyncContext.provider().saveContext();
for (Callable<X> task : tasks) {
wrappedTasks.add(new ContextPreservingCallable<>(task, contextMap));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,9 @@
package io.servicetalk.concurrent.api;

import io.servicetalk.context.api.ContextMap;
import io.servicetalk.context.api.ContextMapHolder;

import java.util.function.BiConsumer;

import static io.servicetalk.concurrent.api.AsyncContextMapThreadLocal.CONTEXT_THREAD_LOCAL;
import static java.util.Objects.requireNonNull;

final class ContextPreservingBiConsumer<T, U> implements BiConsumer<T, U> {
Expand All @@ -34,28 +32,9 @@ final class ContextPreservingBiConsumer<T, U> implements BiConsumer<T, U> {

@Override
public void accept(T t, U u) {
final Thread currentThread = Thread.currentThread();
if (currentThread instanceof ContextMapHolder) {
final ContextMapHolder asyncContextMapHolder = (ContextMapHolder) currentThread;
ContextMap prev = asyncContextMapHolder.context();
try {
asyncContextMapHolder.context(saved);
delegate.accept(t, u);
} finally {
asyncContextMapHolder.context(prev);
}
} else {
slowPath(t, u);
}
}

private void slowPath(T t, U u) {
ContextMap prev = CONTEXT_THREAD_LOCAL.get();
try {
CONTEXT_THREAD_LOCAL.set(saved);
AsyncContextProvider provider = AsyncContext.provider();
try (Scope ignored = provider.attachContext(saved)) {
delegate.accept(t, u);
} finally {
CONTEXT_THREAD_LOCAL.set(prev);
}
}
}
Loading
Loading