fix: range scan more than one onError callback #205

Merged (5 commits) on Jan 26, 2025
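
This change touches the async client implementation and its accompanying unit test: the per-shard RangeScanConsumerWithShard interface is removed, and multi-shard scans now funnel every shard's callbacks through a single SharedRangeScanConsumer, so a failing scan surfaces one onError to the caller instead of one per failing shard. Below is a minimal, illustrative sketch (not part of the PR) of a caller-side consumer that treats onError/onCompleted as strictly terminal; RangeScanConsumer and GetResult are the API types imported in the diff, while the class and field names here are invented for the example.

import io.streamnative.oxia.client.api.GetResult;
import io.streamnative.oxia.client.api.RangeScanConsumer;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.CompletableFuture;

// Illustrative only: collects scan results and exposes them as a future.
// It relies on onError/onCompleted being delivered at most once per scan.
class CollectingRangeScanConsumer implements RangeScanConsumer {
    private final List<GetResult> results = new ArrayList<>();
    private final CompletableFuture<List<GetResult>> future = new CompletableFuture<>();

    @Override
    public synchronized void onNext(GetResult result) {
        results.add(result);
    }

    @Override
    public synchronized void onError(Throwable throwable) {
        // One terminal failure per scan; later shard failures are attached
        // to this throwable as suppressed exceptions.
        future.completeExceptionally(throwable);
    }

    @Override
    public synchronized void onCompleted() {
        // Fires once, after every shard has completed successfully.
        future.complete(List.copyOf(results));
    }

    CompletableFuture<List<GetResult>> asFuture() {
        return future;
    }
}

Before this change, the multi-shard adapter forwarded every shard failure straight to the caller, so a scan spanning several unhealthy shards could invoke onError repeatedly; the shared consumer collapses those into a single onError and records subsequent failures as suppressed exceptions.
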
@@ -54,7 +54,6 @@
import java.time.Duration;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
@@ -684,13 +683,13 @@ public void rangeScan(
@NonNull Set<RangeScanOption> options) {
gaugePendingRangeScanRequests.increment();

RangeScanConsumerWithShard timedConsumer =
new RangeScanConsumerWithShard() {
final RangeScanConsumer timedConsumer =
new RangeScanConsumer() {
final long startTime = System.nanoTime();
final AtomicLong totalSize = new AtomicLong();

@Override
public void onNext(long shardId, GetResult result) {
public void onNext(GetResult result) {
totalSize.addAndGet(result.getValue().length);
consumer.onNext(result);
}
@@ -703,7 +702,7 @@ public void onError(Throwable throwable) {
}

@Override
public void onCompleted(long shardId) {
public void onCompleted() {
gaugePendingRangeScanRequests.decrement();
counterRangeScanBytes.add(totalSize.longValue());
histogramRangeScanLatency.recordSuccess(System.nanoTime() - startTime);
@@ -731,20 +730,12 @@ public void onCompleted(long shardId) {
}
}

interface RangeScanConsumerWithShard {
void onNext(long shardId, GetResult result);

void onError(Throwable throwable);

void onCompleted(long shardId);
}

private void internalShardRangeScan(
long shardId,
String startKeyInclusive,
String endKeyExclusive,
Optional<String> secondaryIndexName,
RangeScanConsumerWithShard consumer) {
RangeScanConsumer consumer) {
var leader = shardManager.leader(shardId);
var stub = stubManager.getStub(leader);
var requestBuilder =
@@ -763,8 +754,7 @@ private void internalShardRangeScan(
@Override
public void onNext(RangeScanResponse response) {
for (int i = 0; i < response.getRecordsCount(); i++) {
consumer.onNext(
shardId, ProtoUtil.getResultFromProto("", response.getRecords(i)));
consumer.onNext(ProtoUtil.getResultFromProto("", response.getRecords(i)));
}
}

@@ -775,7 +765,7 @@ public void onError(Throwable t) {

@Override
public void onCompleted() {
consumer.onCompleted(shardId);
consumer.onCompleted();
}
});
}
@@ -784,41 +774,60 @@ private void internalRangeScanMultiShards(
String startKeyInclusive,
String endKeyExclusive,
Optional<String> secondaryIndexName,
RangeScanConsumerWithShard consumer) {
Set<Long> shardIds = shardManager.allShardIds();
RangeScanConsumer consumer) {
final Set<Long> shardIds = shardManager.allShardIds();
final RangeScanConsumer multiShardConsumer =
new SharedRangeScanConsumer(shardIds.size(), consumer);
for (long shardId : shardIds) {
internalShardRangeScan(
shardId, startKeyInclusive, endKeyExclusive, secondaryIndexName, multiShardConsumer);
}
}

RangeScanConsumerWithShard multiShardConsumer =
new RangeScanConsumerWithShard() {
private final Set<Long> pendingShards = new HashSet<>(shardIds);
private boolean failed = false;
static class SharedRangeScanConsumer implements RangeScanConsumer {
private final RangeScanConsumer delegate;

@Override
public synchronized void onNext(long shardId, GetResult result) {
if (!failed) {
consumer.onNext(shardId, result);
}
}
private int pendingCompletedRequests;
private boolean completed = false;
private Throwable completedException = null;

@Override
public synchronized void onError(Throwable throwable) {
failed = true;
consumer.onError(throwable);
}
SharedRangeScanConsumer(int shards, RangeScanConsumer delegate) {
this.pendingCompletedRequests = shards;
this.delegate = delegate;
}

@Override
public synchronized void onCompleted(long shardId) {
if (!failed) {
pendingShards.remove(shardId);
if (pendingShards.isEmpty()) {
consumer.onCompleted(shardId);
}
}
}
};
@Override
public synchronized void onNext(GetResult result) {
if (completed) {
return;
}
delegate.onNext(result);
}

for (long shardId : shardIds) {
internalShardRangeScan(
shardId, startKeyInclusive, endKeyExclusive, secondaryIndexName, multiShardConsumer);
@Override
public synchronized void onError(Throwable throwable) {
if (completedException == null) {
completedException = throwable;
} else {
completedException.addSuppressed(throwable);
}
if (completed) {
return;
}
completed = true;
delegate.onError(throwable);
}

@Override
public synchronized void onCompleted() {
if (completed) {
return;
}
pendingCompletedRequests -= 1;
if (pendingCompletedRequests == 0) {
completed = true;
delegate.onCompleted();
}
}
}

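SharedRangeScanConsumer counts down one pending completion per shard and flips a single completed flag under the instance lock: onCompleted reaches the delegate only after every shard has finished, while the first onError terminates the scan immediately and later errors are only recorded via addSuppressed. The throwaway driver below (not part of the PR) makes the resulting callback sequence concrete for a three-shard scan in which two shards fail; it assumes same-package access to the package-private nested class, as the unit test further down has.

// Assumed package of AsyncOxiaClientImpl; must match for package-private access.
package io.streamnative.oxia.client;

import io.streamnative.oxia.client.api.GetResult;
import io.streamnative.oxia.client.api.RangeScanConsumer;

class SharedRangeScanConsumerSketch {
    public static void main(String[] args) {
        RangeScanConsumer delegate =
                new RangeScanConsumer() {
                    @Override
                    public void onNext(GetResult result) {}

                    @Override
                    public void onError(Throwable throwable) {
                        // Reached at most once per scan.
                        System.out.println("onError: " + throwable.getMessage());
                    }

                    @Override
                    public void onCompleted() {
                        // Reached once, and only if no shard failed.
                        System.out.println("onCompleted");
                    }
                };

        var shared = new AsyncOxiaClientImpl.SharedRangeScanConsumer(3, delegate);
        shared.onError(new IllegalStateException("shard 0 failed")); // delegate sees this throwable
        shared.onError(new IllegalStateException("shard 1 failed")); // added as suppressed, no second callback
        shared.onCompleted();                                        // ignored: the scan already terminated
    }
}
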
@@ -32,6 +32,7 @@
import io.streamnative.oxia.client.api.DeleteOption;
import io.streamnative.oxia.client.api.GetResult;
import io.streamnative.oxia.client.api.PutResult;
import io.streamnative.oxia.client.api.RangeScanConsumer;
import io.streamnative.oxia.client.api.Version;
import io.streamnative.oxia.client.batch.BatchManager;
import io.streamnative.oxia.client.batch.Batcher;
@@ -49,10 +50,18 @@
import io.streamnative.oxia.proto.ListResponse;
import io.streamnative.oxia.proto.OxiaClientGrpc;
import java.time.Duration;
import java.util.ArrayList;
import java.util.List;
import java.util.Set;
import java.util.concurrent.*;
import java.util.concurrent.CompletionException;
import java.util.concurrent.Executors;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.ForkJoinTask;
import java.util.concurrent.TimeoutException;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Supplier;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
@@ -566,4 +575,118 @@ void close() throws Exception {
inOrder.verify(stubManager).close();
client = null;
}

@Test
void testShardShardRangeScanConsumer() {
final int shards = 5;
final List<GetResult> results = new ArrayList<>();
final AtomicInteger onErrorCount = new AtomicInteger(0);
final AtomicInteger onCompletedCount = new AtomicInteger(0);
final Supplier<RangeScanConsumer> newShardRangeScanConsumer =
() ->
new AsyncOxiaClientImpl.SharedRangeScanConsumer(
5,
new RangeScanConsumer() {
@Override
public void onNext(GetResult result) {
results.add(result);
}

@Override
public void onError(Throwable throwable) {
onErrorCount.incrementAndGet();
}

@Override
public void onCompleted() {
onCompletedCount.incrementAndGet();
}
});
final var tasks = new ArrayList<ForkJoinTask<?>>();

// (1) complete ok
final var shardRangeScanConsumer1 = newShardRangeScanConsumer.get();
for (int i = 0; i < shards; i++) {
final int fi = i;
final ForkJoinTask<?> task =
ForkJoinPool.commonPool()
.submit(
() -> {
shardRangeScanConsumer1.onNext(
new GetResult(
"shard-" + fi + "-0",
new byte[10],
new Version(1, 2, 3, 4, empty(), empty())));
shardRangeScanConsumer1.onNext(
new GetResult(
"shard-" + fi + "-1",
new byte[10],
new Version(1, 2, 3, 4, empty(), empty())));
shardRangeScanConsumer1.onCompleted();
});
tasks.add(task);
}
tasks.forEach(ForkJoinTask::join);
var keys = results.stream().map(GetResult::getKey).toList();
for (int i = 0; i < shards; i++) {
Assertions.assertTrue(keys.contains("shard-" + i + "-0"));
Assertions.assertTrue(keys.contains("shard-" + i + "-1"));
}
Assertions.assertEquals(0, onErrorCount.get());
Assertions.assertEquals(1, onCompletedCount.get());

tasks.clear();
onErrorCount.set(0);
onCompletedCount.set(0);
results.clear();

// (2) complete partial exception
final var shardRangeScanConsumer2 = newShardRangeScanConsumer.get();
for (int i = 0; i < shards; i++) {
final int fi = i;
final ForkJoinTask<?> task =
ForkJoinPool.commonPool()
.submit(
() -> {
if (fi % 2 == 0) {
shardRangeScanConsumer2.onError(new IllegalStateException());
return;
}
shardRangeScanConsumer2.onNext(
new GetResult(
"shard-" + fi + "-0",
new byte[10],
new Version(1, 2, 3, 4, empty(), empty())));
shardRangeScanConsumer2.onNext(
new GetResult(
"shard-" + fi + "-1",
new byte[10],
new Version(1, 2, 3, 4, empty(), empty())));
shardRangeScanConsumer2.onCompleted();
});
tasks.add(task);
}
tasks.forEach(ForkJoinTask::join);

Assertions.assertEquals(1, onErrorCount.get());
Assertions.assertEquals(0, onCompletedCount.get());

tasks.clear();
onErrorCount.set(0);
onCompletedCount.set(0);
results.clear();

// (3) complete all exception
final var shardRangeScanConsumer3 = newShardRangeScanConsumer.get();
for (int i = 0; i < shards; i++) {
final ForkJoinTask<?> task =
ForkJoinPool.commonPool()
.submit(() -> shardRangeScanConsumer3.onError(new IllegalStateException()));
tasks.add(task);
}
tasks.forEach(ForkJoinTask::join);
Assertions.assertEquals(1, onErrorCount.get());
Assertions.assertEquals(0, onCompletedCount.get());
Assertions.assertEquals(0, results.size());
}
}