From 0a1fbddb0e40bf4b860a5c153fbdd7cf618d462f Mon Sep 17 00:00:00 2001 From: Scott Mitchell Date: Sat, 1 Oct 2022 13:44:26 -0700 Subject: [PATCH] PublisherAsBlockingIterable LinkedBlockingQueue -> SpscBlockingQueue Motivation: LinkedBlockingQueue allows for multiple producers and multiple consumers. It uses LockSupport park in take (the consumer waits for data) and unpark in offer (the producer wakes the consumer). LockSupport unpark on the EventLoop thread has been shown to impact throughput during benchmarks. Before: ``` Running 30s test @ http://localhost:8080/medium, using 'ServiceTalkGrpcBlockingClientStrAgg' client 1024 threads and 1024 connections Thread Stats Avg Stdev Max +/- Stdev Latency - - - - Req/Sec 0.01k - 0.01k - 262338 requests in 30s Requests/sec: 8744.60 Transfer/sec: - OK: 262338 KO: 0 ``` After: ``` Running 30s test @ http://localhost:8080/medium, using 'ServiceTalkGrpcBlockingClientStrAgg' client 1024 threads and 1024 connections Thread Stats Avg Stdev Max +/- Stdev Latency - - - - Req/Sec 0.01k - 0.01k - 326478 requests in 30s Requests/sec: 10882.60 Transfer/sec: - OK: 326478 KO: 0 ``` --- .../api/PublisherAsBlockingIterable.java | 311 +++++++++++++++++- 1 file changed, 308 insertions(+), 3 deletions(-) diff --git a/servicetalk-concurrent-api/src/main/java/io/servicetalk/concurrent/api/PublisherAsBlockingIterable.java b/servicetalk-concurrent-api/src/main/java/io/servicetalk/concurrent/api/PublisherAsBlockingIterable.java index 229470f018..29c1a6676c 100644 --- a/servicetalk-concurrent-api/src/main/java/io/servicetalk/concurrent/api/PublisherAsBlockingIterable.java +++ b/servicetalk-concurrent-api/src/main/java/io/servicetalk/concurrent/api/PublisherAsBlockingIterable.java @@ -27,19 +27,23 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import java.util.Collection; import java.util.Iterator; import java.util.NoSuchElementException; +import java.util.Queue; import java.util.concurrent.BlockingQueue; -import java.util.concurrent.LinkedBlockingQueue; import java.util.concurrent.TimeUnit; import 
java.util.concurrent.TimeoutException;
+import java.util.concurrent.atomic.AtomicReferenceFieldUpdater;
+import java.util.concurrent.locks.LockSupport;
 import javax.annotation.Nullable;
 
 import static io.servicetalk.concurrent.api.SubscriberApiUtils.unwrapNullUnchecked;
 import static io.servicetalk.concurrent.api.SubscriberApiUtils.wrapNull;
 import static io.servicetalk.concurrent.internal.TerminalNotification.complete;
 import static io.servicetalk.concurrent.internal.TerminalNotification.error;
-import static io.servicetalk.utils.internal.PlatformDependent.throwException;
+import static io.servicetalk.utils.internal.PlatformDependent.newUnboundedSpscQueue;
+import static io.servicetalk.utils.internal.ThrowableUtils.throwException;
 import static java.lang.Math.min;
 import static java.lang.Thread.currentThread;
 import static java.util.Objects.requireNonNull;
@@ -101,7 +105,7 @@ private static final class SubscriberAndIterator<T> implements Subscriber<T>, Bl
 
         SubscriberAndIterator(int queueCapacity) {
             requestN = queueCapacity;
-            data = new LinkedBlockingQueue<>();
+            data = new SpscBlockingQueue<>(newUnboundedSpscQueue(queueCapacity));
         }
 
         @Override
@@ -261,4 +265,317 @@ private T processNext() {
             return unwrapNullUnchecked(signal);
         }
     }
+
+    /**
+     * A {@link BlockingQueue} adapter over a queue intended for a single producer and a single consumer.
+     * <p>
+     * Insertions wake a waiting consumer; {@link #take()} and {@link #poll(long, TimeUnit)} yield before parking
+     * while the queue is empty. {@link #put(Object)} and {@link #offer(Object, long, TimeUnit)} are unsupported.
+     *
+     * @param <T> the type of elements held in this queue.
+     */
+    private static final class SpscBlockingQueue<T> implements BlockingQueue<T> {
+        /**
+         * Amount of times to call {@link Thread#yield()} before calling {@link LockSupport#park()}.
+         * {@link LockSupport#park()} can be expensive and if the producer is generating data it is likely we will see
+         * it without park/unpark.
+         */
+        private static final int POLL_YIELD_SPIN_COUNT =
+                Integer.getInteger("io.servicetalk.concurrent.internal.blockingIterableYieldSpinCount", 1);
+        /**
+         * Amount of nanoseconds to spin on {@link Thread#yield()} before calling {@link LockSupport#parkNanos(long)}.
+         * {@link LockSupport#parkNanos(long)} can be expensive and if the producer is generating data it is likely
+         * we will see it without park/unpark.
+         */
+        private static final long POLL_YIELD_SPIN_NS =
+                Long.getLong("io.servicetalk.concurrent.internal.blockingIterableYieldSpinNs", 1024);
+        @SuppressWarnings("rawtypes")
+        private static final AtomicReferenceFieldUpdater<SpscBlockingQueue, Thread> consumerThreadUpdater =
+                AtomicReferenceFieldUpdater.newUpdater(SpscBlockingQueue.class, Thread.class, "consumerThread");
+        private static final Thread PRODUCED_THREAD = new Thread(() -> { });
+        private final Queue<T> spscQueue;
+        @Nullable
+        private volatile Thread consumerThread;
+
+        SpscBlockingQueue(Queue<T> spscQueue) {
+            this.spscQueue = requireNonNull(spscQueue);
+        }
+
+        @Override
+        public boolean add(final T t) {
+            if (spscQueue.add(t)) {
+                signalConsumer();
+                return true;
+            }
+            return false;
+        }
+
+        @Override
+        public boolean offer(final T t) {
+            if (spscQueue.offer(t)) {
+                signalConsumer();
+                return true;
+            }
+            return false;
+        }
+
+        /**
+         * Records that a signal was produced and unparks the consumer thread if one is registered as waiting.
+         */
+        private void signalConsumer() {
+            final Thread thread = consumerThreadUpdater.getAndSet(this, PRODUCED_THREAD);
+            if (thread != null && thread != PRODUCED_THREAD) {
+                LockSupport.unpark(thread);
+            }
+        }
+
+        @Override
+        public T remove() {
+            return spscQueue.remove();
+        }
+
+        @Override
+        public T poll() {
+            return spscQueue.poll();
+        }
+
+        @Override
+        public T element() {
+            // Queue#element() must not remove the head, so inspect it with peek() rather than poll().
+            final T t = peek();
+            if (t == null) {
+                throw new NoSuchElementException();
+            }
+            return t;
+        }
+
+        @Override
+        public T peek() {
+            return spscQueue.peek();
+        }
+
+        @Override
+        public void put(final T t) {
+            throw new UnsupportedOperationException();
+        }
+
+        @Override
+        public boolean offer(final T t, final long timeout, final TimeUnit unit) {
+            throw new UnsupportedOperationException();
+        }
+
+        @Override
+        public T take() throws InterruptedException {
+            final Thread currentThread = Thread.currentThread();
+            for (;;) {
+                final Thread thread = consumerThread;
+                if (thread != null && thread != currentThread && thread != PRODUCED_THREAD) {
+                    throwTooManyConsumers(currentThread);
+                } else if (thread == currentThread ||
+                        consumerThreadUpdater.compareAndSet(this, thread, currentThread)) {
+                    try {
+                        T item;
+                        int pollCount = 0;
+                        while ((item = spscQueue.poll()) == null) {
+                            // Benchmarks show that park/unpark is expensive when producer is the EventLoop thread and
+                            // unpark has to wakeup a thread that is parked. Yield has been shown to lower this cost
+                            // on the EventLoop thread and increase throughput in these scenarios.
+                            if (pollCount++ > POLL_YIELD_SPIN_COUNT) {
+                                LockSupport.park();
+                            } else {
+                                Thread.yield();
+                            }
+                            checkInterrupted();
+                        }
+
+                        return item;
+                    } finally {
+                        // If this call changed the consumerThread before the poll call we should restore it after.
+                        // This should be done atomically in case another thread has produced concurrently and swapped
+                        // the value to PRODUCED_THREAD.
+                        if (thread != currentThread) {
+                            consumerThreadUpdater.compareAndSet(this, currentThread, null);
+                        }
+                    }
+                }
+            }
+        }
+
+        @Override
+        public T poll(final long timeout, final TimeUnit unit) throws InterruptedException {
+            final Thread currentThread = Thread.currentThread();
+            for (;;) {
+                final Thread thread = consumerThread;
+                if (thread != null && thread != currentThread && thread != PRODUCED_THREAD) {
+                    throwTooManyConsumers(currentThread);
+                } else if (thread == currentThread ||
+                        consumerThreadUpdater.compareAndSet(this, thread, currentThread)) {
+                    try {
+                        final long originalNs = unit.toNanos(timeout);
+                        long remainingNs = originalNs;
+                        long beforeTimeNs = System.nanoTime();
+                        T item;
+                        while ((item = spscQueue.poll()) == null) {
+                            // Benchmarks show that park/unpark is expensive when producer is the EventLoop thread and
+                            // unpark has to wakeup a thread that is parked. Yield has been shown to lower this cost
+                            // on the EventLoop thread and increase throughput in these scenarios.
+                            if (originalNs - remainingNs > POLL_YIELD_SPIN_NS) {
+                                LockSupport.parkNanos(remainingNs);
+                            } else {
+                                Thread.yield();
+                            }
+                            checkInterrupted();
+                            final long afterTimeNs = System.nanoTime();
+                            final long durationNs = afterTimeNs - beforeTimeNs;
+                            if (durationNs > remainingNs) {
+                                return null;
+                            }
+                            remainingNs -= durationNs;
+                            beforeTimeNs = afterTimeNs;
+                        }
+
+                        return item;
+                    } finally {
+                        // If this call changed the consumerThread before the poll call we should restore it after.
+                        // This should be done atomically in case another thread has produced concurrently and swapped
+                        // the value to PRODUCED_THREAD.
+                        if (thread != currentThread) {
+                            consumerThreadUpdater.compareAndSet(this, currentThread, null);
+                        }
+                    }
+                }
+            }
+        }
+
+        private static void throwTooManyConsumers(Thread currentThread) {
+            throw new IllegalStateException("Only single consumer allowed, current consumer: " + currentThread);
+        }
+
+        private static void checkInterrupted() throws InterruptedException {
+            if (Thread.interrupted()) {
+                throw new InterruptedException();
+            }
+        }
+
+        @Override
+        public int remainingCapacity() {
+            return Integer.MAX_VALUE;
+        }
+
+        @Override
+        public boolean remove(final Object o) {
+            if (spscQueue.remove(o)) {
+                signalConsumer();
+                return true;
+            }
+            return false;
+        }
+
+        @Override
+        public boolean containsAll(final Collection<?> c) {
+            return spscQueue.containsAll(c);
+        }
+
+        @Override
+        public boolean addAll(final Collection<? extends T> c) {
+            if (spscQueue.addAll(c)) {
+                signalConsumer();
+                return true;
+            }
+            return false;
+        }
+
+        @Override
+        public boolean removeAll(final Collection<?> c) {
+            if (spscQueue.removeAll(c)) {
+                signalConsumer();
+                return true;
+            }
+            return false;
+        }
+
+        @Override
+        public boolean retainAll(final Collection<?> c) {
+            if (spscQueue.retainAll(c)) {
+                signalConsumer();
+                return true;
+            }
+            return false;
+        }
+
+        @Override
+        public void clear() {
+            spscQueue.clear();
+            signalConsumer();
+        }
+
+        @Override
+        public int size() {
+            return spscQueue.size();
+        }
+
+        @Override
+        public boolean isEmpty() {
+            return spscQueue.isEmpty();
+        }
+
+        @Override
+        public boolean contains(final Object o) {
+            return spscQueue.contains(o);
+        }
+
+        @Override
+        public Iterator<T> iterator() {
+            return spscQueue.iterator();
+        }
+
+        @Override
+        public Object[] toArray() {
+            return spscQueue.toArray();
+        }
+
+        @Override
+        public <T1> T1[] toArray(final T1[] a) {
+            return spscQueue.toArray(a);
+        }
+
+        @Override
+        public int drainTo(final Collection<? super T> c) {
+            int i = 0;
+            T item;
+            while ((item = poll()) != null) {
+                if (c.add(item)) {
+                    ++i;
+                }
+            }
+            return i;
+        }
+
+        @Override
+        public int drainTo(final Collection<? super T> c, final int maxElements) {
+            int i = 0;
+            T item;
+            while (i < maxElements && (item = poll()) != null) {
+                if (c.add(item)) {
+                    ++i;
+                }
+            }
+            return i;
+        }
+
+        @Override
+        public boolean equals(Object o) {
+            return o instanceof SpscBlockingQueue && spscQueue.equals(((SpscBlockingQueue<?>) o).spscQueue);
+        }
+
+        @Override
+        public int hashCode() {
+            return spscQueue.hashCode();
+        }
+
+        @Override
+        public String toString() {
+            return spscQueue.toString();
+        }
+    }
 }