diff --git a/servicetalk-concurrent-api/src/main/java/io/servicetalk/concurrent/api/PublisherAsBlockingIterable.java b/servicetalk-concurrent-api/src/main/java/io/servicetalk/concurrent/api/PublisherAsBlockingIterable.java
index 229470f018..29c1a6676c 100644
--- a/servicetalk-concurrent-api/src/main/java/io/servicetalk/concurrent/api/PublisherAsBlockingIterable.java
+++ b/servicetalk-concurrent-api/src/main/java/io/servicetalk/concurrent/api/PublisherAsBlockingIterable.java
@@ -27,19 +27,23 @@ import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
+import java.util.Collection;
 import java.util.Iterator;
 import java.util.NoSuchElementException;
+import java.util.Queue;
 import java.util.concurrent.BlockingQueue;
-import java.util.concurrent.LinkedBlockingQueue;
 import java.util.concurrent.TimeUnit;
 import java.util.concurrent.TimeoutException;
+import java.util.concurrent.atomic.AtomicReferenceFieldUpdater;
+import java.util.concurrent.locks.LockSupport;
 import javax.annotation.Nullable;
 
 import static io.servicetalk.concurrent.api.SubscriberApiUtils.unwrapNullUnchecked;
 import static io.servicetalk.concurrent.api.SubscriberApiUtils.wrapNull;
 import static io.servicetalk.concurrent.internal.TerminalNotification.complete;
 import static io.servicetalk.concurrent.internal.TerminalNotification.error;
-import static io.servicetalk.utils.internal.PlatformDependent.throwException;
+import static io.servicetalk.utils.internal.PlatformDependent.newUnboundedSpscQueue;
+import static io.servicetalk.utils.internal.ThrowableUtils.throwException;
 import static java.lang.Math.min;
 import static java.lang.Thread.currentThread;
 import static java.util.Objects.requireNonNull;
@@ -101,7 +105,7 @@ private static final class SubscriberAndIterator<T> implements Subscriber<T>, Bl
 
         SubscriberAndIterator(int queueCapacity) {
             requestN = queueCapacity;
-            data = new LinkedBlockingQueue<>();
+            data = new SpscBlockingQueue<>(newUnboundedSpscQueue(queueCapacity));
         }
 
         @Override
@@ -261,4 +265,305 @@ private T processNext() {
             return unwrapNullUnchecked(signal);
         }
     }
+
+    private static final class SpscBlockingQueue<T> implements BlockingQueue<T> {
+        /**
+         * Amount of times to call {@link Thread#yield()} before calling {@link LockSupport#park()}.
+         * {@link LockSupport#park()} can be expensive and if the producer is generating data it is likely we will see
+         * it without park/unpark.
+         */
+        private static final int POLL_YIELD_SPIN_COUNT =
+                Integer.getInteger("io.servicetalk.concurrent.internal.blockingIterableYieldSpinCount", 1);
+        /**
+         * Amount of nanoseconds to spin on {@link Thread#yield()} before calling {@link LockSupport#parkNanos(long)}.
+         * {@link LockSupport#parkNanos(long)} can be expensive and if the producer is generating data it is likely
+         * we will see it without park/unpark.
+         */
+        private static final long POLL_YIELD_SPIN_NS =
+                Long.getLong("io.servicetalk.concurrent.internal.blockingIterableYieldSpinNs", 1024);
+        @SuppressWarnings("rawtypes")
+        private static final AtomicReferenceFieldUpdater<SpscBlockingQueue, Thread> consumerThreadUpdater =
+                AtomicReferenceFieldUpdater.newUpdater(SpscBlockingQueue.class, Thread.class, "consumerThread");
+        private static final Thread PRODUCED_THREAD = new Thread(() -> { });
+        private final Queue<T> spscQueue;
+        @Nullable
+        private volatile Thread consumerThread;
+
+        SpscBlockingQueue(Queue<T> spscQueue) {
+            this.spscQueue = requireNonNull(spscQueue);
+        }
+
+        @Override
+        public boolean add(final T t) {
+            if (spscQueue.add(t)) {
+                signalConsumer();
+                return true;
+            }
+            return false;
+        }
+
+        @Override
+        public boolean offer(final T t) {
+            if (spscQueue.offer(t)) {
+                signalConsumer();
+                return true;
+            }
+            return false;
+        }
+
+        private void signalConsumer() {
+            final Thread thread = consumerThreadUpdater.getAndSet(this, PRODUCED_THREAD);
+            if (thread != null && thread != PRODUCED_THREAD) {
+                LockSupport.unpark(thread);
+            }
+        }
+
+        @Override
+        public T remove() {
+            return spscQueue.remove();
+        }
+
+        @Override
+        public T poll() {
+            return spscQueue.poll();
+        }
+
+        @Override
+        public T element() {
+            // element() must not remove the head, so peek (not poll) and throw if empty.
+            final T t = peek();
+            if (t == null) {
+                throw new NoSuchElementException();
+            }
+            return t;
+        }
+
+        @Override
+        public T peek() {
+            return spscQueue.peek();
+        }
+
+        @Override
+        public void put(final T t) {
+            throw new UnsupportedOperationException();
+        }
+
+        @Override
+        public boolean offer(final T t, final long timeout, final TimeUnit unit) {
+            throw new UnsupportedOperationException();
+        }
+
+        @Override
+        public T take() throws InterruptedException {
+            final Thread currentThread = Thread.currentThread();
+            for (;;) {
+                final Thread thread = consumerThread;
+                if (thread != null && thread != currentThread && thread != PRODUCED_THREAD) {
+                    throwTooManyConsumers(currentThread);
+                } else if (thread == currentThread ||
+                        consumerThreadUpdater.compareAndSet(this, thread, currentThread)) {
+                    try {
+                        T item;
+                        int pollCount = 0;
+                        while ((item = spscQueue.poll()) == null) {
+                            // Benchmarks show that park/unpark is expensive when producer is the EventLoop thread and
+                            // unpark has to wake up a thread that is parked. Yield has been shown to lower this cost
+                            // on the EventLoop thread and increase throughput in these scenarios.
+                            if (pollCount++ > POLL_YIELD_SPIN_COUNT) {
+                                LockSupport.park();
+                            } else {
+                                Thread.yield();
+                            }
+                            checkInterrupted();
+                        }
+
+                        return item;
+                    } finally {
+                        // If this call changed the consumerThread before the poll call we should restore it after.
+                        // This should be done atomically in case another thread has produced concurrently and swapped
+                        // the value to PRODUCED_THREAD.
+                        if (thread != currentThread) {
+                            consumerThreadUpdater.compareAndSet(this, currentThread, null);
+                        }
+                    }
+                }
+            }
+        }
+
+        @Override
+        public T poll(final long timeout, final TimeUnit unit) throws InterruptedException {
+            final Thread currentThread = Thread.currentThread();
+            for (;;) {
+                final Thread thread = consumerThread;
+                if (thread != null && thread != currentThread && thread != PRODUCED_THREAD) {
+                    throwTooManyConsumers(currentThread);
+                } else if (thread == currentThread ||
+                        consumerThreadUpdater.compareAndSet(this, thread, currentThread)) {
+                    try {
+                        final long originalNs = unit.toNanos(timeout);
+                        long remainingNs = originalNs;
+                        long beforeTimeNs = System.nanoTime();
+                        T item;
+                        while ((item = spscQueue.poll()) == null) {
+                            // Benchmarks show that park/unpark is expensive when producer is the EventLoop thread and
+                            // unpark has to wake up a thread that is parked. Yield has been shown to lower this cost
+                            // on the EventLoop thread and increase throughput in these scenarios.
+                            if (originalNs - remainingNs > POLL_YIELD_SPIN_NS) {
+                                LockSupport.parkNanos(remainingNs);
+                            } else {
+                                Thread.yield();
+                            }
+                            checkInterrupted();
+                            final long afterTimeNs = System.nanoTime();
+                            final long durationNs = afterTimeNs - beforeTimeNs;
+                            if (durationNs > remainingNs) {
+                                return null;
+                            }
+                            remainingNs -= durationNs;
+                            beforeTimeNs = afterTimeNs;
+                        }
+
+                        return item;
+                    } finally {
+                        // If this call changed the consumerThread before the poll call we should restore it after.
+                        // This should be done atomically in case another thread has produced concurrently and swapped
+                        // the value to PRODUCED_THREAD.
+                        if (thread != currentThread) {
+                            consumerThreadUpdater.compareAndSet(this, currentThread, null);
+                        }
+                    }
+                }
+            }
+        }
+
+        private static void throwTooManyConsumers(Thread currentThread) {
+            throw new IllegalStateException("Only single consumer allowed, current consumer: " + currentThread);
+        }
+
+        private static void checkInterrupted() throws InterruptedException {
+            if (Thread.interrupted()) {
+                throw new InterruptedException();
+            }
+        }
+
+        @Override
+        public int remainingCapacity() {
+            return Integer.MAX_VALUE;
+        }
+
+        @Override
+        public boolean remove(final Object o) {
+            if (spscQueue.remove(o)) {
+                signalConsumer();
+                return true;
+            }
+            return false;
+        }
+
+        @Override
+        public boolean containsAll(final Collection<?> c) {
+            return spscQueue.containsAll(c);
+        }
+
+        @Override
+        public boolean addAll(final Collection<? extends T> c) {
+            if (spscQueue.addAll(c)) {
+                signalConsumer();
+                return true;
+            }
+            return false;
+        }
+
+        @Override
+        public boolean removeAll(final Collection<?> c) {
+            if (spscQueue.removeAll(c)) {
+                signalConsumer();
+                return true;
+            }
+            return false;
+        }
+
+        @Override
+        public boolean retainAll(final Collection<?> c) {
+            if (spscQueue.retainAll(c)) {
+                signalConsumer();
+                return true;
+            }
+            return false;
+        }
+
+        @Override
+        public void clear() {
+            spscQueue.clear();
+            signalConsumer();
+        }
+
+        @Override
+        public int size() {
+            return spscQueue.size();
+        }
+
+        @Override
+        public boolean isEmpty() {
+            return spscQueue.isEmpty();
+        }
+
+        @Override
+        public boolean contains(final Object o) {
+            return spscQueue.contains(o);
+        }
+
+        @Override
+        public Iterator<T> iterator() {
+            return spscQueue.iterator();
+        }
+
+        @Override
+        public Object[] toArray() {
+            return spscQueue.toArray();
+        }
+
+        @Override
+        public <T1> T1[] toArray(final T1[] a) {
+            return spscQueue.toArray(a);
+        }
+
+        @Override
+        public int drainTo(final Collection<? super T> c) {
+            int i = 0;
+            T item;
+            while ((item = poll()) != null) {
+                if (c.add(item)) {
+                    ++i;
+                }
+            }
+            return i;
+        }
+
+        @Override
+        public int drainTo(final Collection<? super T> c, final int maxElements) {
+            int i = 0;
+            T item;
+            while (i < maxElements && (item = poll()) != null) {
+                if (c.add(item)) {
+                    ++i;
+                }
+            }
+            return i;
+        }
+
+        @Override
+        public boolean equals(Object o) {
+            return o instanceof SpscBlockingQueue && spscQueue.equals(((SpscBlockingQueue<?>) o).spscQueue);
+        }
+
+        @Override
+        public int hashCode() {
+            return spscQueue.hashCode();
+        }
+
+        @Override
+        public String toString() {
+            return spscQueue.toString();
+        }
+    }
 }
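
For reviewers, a hedged usage sketch of the consumer-side API this queue sits behind: Publisher#toIterable(int) is the public entry point that constructs PublisherAsBlockingIterable, whose SubscriberAndIterator now drains the SpscBlockingQueue above from a single consumer thread. The class name ToIterableExample and the capacity hint value 16 are illustrative only and are not part of this change.

import io.servicetalk.concurrent.BlockingIterable;
import io.servicetalk.concurrent.api.Publisher;

public final class ToIterableExample {
    public static void main(String[] args) {
        // The queue capacity hint (16 here, arbitrary) arrives as queueCapacity in the
        // SubscriberAndIterator constructor above, sizing requestN and the SPSC queue chunk.
        BlockingIterable<Integer> iterable = Publisher.from(1, 2, 3).toIterable(16);
        // Iteration must stay on one thread: a second concurrent consumer trips the
        // "Only single consumer allowed" IllegalStateException in take()/poll(timeout, unit).
        for (Integer item : iterable) {
            System.out.println(item);
        }
    }
}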