From cc5ea548f4d37fdfea61047467027ad9cad7018b Mon Sep 17 00:00:00 2001 From: Yuri Schimke Date: Thu, 21 Sep 2017 10:40:02 +0100 Subject: [PATCH] Fix requestChannel to not drop first payload (#398) --- .../main/java/io/rsocket/RSocketServer.java | 14 +++++--- .../io/rsocket/test/BaseClientServerTest.java | 34 ++++++++++++++++++- .../java/io/rsocket/test/TestRSocket.java | 7 ++++ 3 files changed, 49 insertions(+), 6 deletions(-) diff --git a/rsocket-core/src/main/java/io/rsocket/RSocketServer.java b/rsocket-core/src/main/java/io/rsocket/RSocketServer.java index deb981b4c..02e93b47b 100644 --- a/rsocket-core/src/main/java/io/rsocket/RSocketServer.java +++ b/rsocket-core/src/main/java/io/rsocket/RSocketServer.java @@ -16,13 +16,13 @@ package io.rsocket; +import static io.rsocket.Frame.Request.initialRequestN; import static io.rsocket.frame.FrameHeaderFlyweight.FLAGS_C; import static io.rsocket.frame.FrameHeaderFlyweight.FLAGS_M; import io.netty.buffer.ByteBuf; import io.netty.buffer.Unpooled; import io.netty.util.collection.IntObjectHashMap; -import io.rsocket.Frame.Request; import io.rsocket.exceptions.ApplicationException; import io.rsocket.internal.LimitableRequestPublisher; import io.rsocket.util.PayloadImpl; @@ -157,7 +157,8 @@ private Mono handleFrame(Frame frame) { case REQUEST_N: return handleRequestN(streamId, frame); case REQUEST_STREAM: - return handleStream(streamId, requestStream(new PayloadImpl(frame)), frame); + return handleStream( + streamId, requestStream(new PayloadImpl(frame)), initialRequestN(frame)); case REQUEST_CHANNEL: return handleChannel(streamId, frame); case PAYLOAD: @@ -235,8 +236,7 @@ private Mono handleRequestResponse(int streamId, Mono response) { return responseFrame.flatMap(connection::sendOne); } - private Mono handleStream(int streamId, Flux response, Frame firstFrame) { - int initialRequestN = Request.initialRequestN(firstFrame); + private Mono handleStream(int streamId, Flux response, int initialRequestN) { Flux responseFrames = response .map(payload -> Frame.PayloadFrame.from(streamId, FrameType.NEXT, payload)) @@ -287,7 +287,11 @@ private Mono handleChannel(int streamId, Frame firstFrame) { }) .doFinally(signalType -> removeChannelProcessor(streamId)); - return handleStream(streamId, requestChannel(payloads), firstFrame); + // not chained, as the payload should be enqueued in the Unicast processor before this method returns + // and any later payload can be processed + frames.onNext(new PayloadImpl(firstFrame)); + + return handleStream(streamId, requestChannel(payloads), initialRequestN(firstFrame)); } private Mono handleKeepAliveFrame(Frame frame) { diff --git a/rsocket-test/src/main/java/io/rsocket/test/BaseClientServerTest.java b/rsocket-test/src/main/java/io/rsocket/test/BaseClientServerTest.java index 6ca2e8870..5c8e64264 100644 --- a/rsocket-test/src/main/java/io/rsocket/test/BaseClientServerTest.java +++ b/rsocket-test/src/main/java/io/rsocket/test/BaseClientServerTest.java @@ -20,6 +20,7 @@ import io.rsocket.Payload; import io.rsocket.util.PayloadImpl; +import org.junit.Ignore; import org.junit.Rule; import org.junit.Test; import reactor.core.publisher.Flux; @@ -79,7 +80,7 @@ public void testRequestResponse10() { assertEquals(10, outputCount); } - private PayloadImpl testPayload(int metadataPresent) { + private Payload testPayload(int metadataPresent) { String metadata; switch (metadataPresent % 5) { case 0: @@ -164,4 +165,35 @@ public void testRequestStreamWithDelayedRequestN() { assertEquals(10, ts.count()); } + + @Test(timeout = 10000) + @Ignore + public void testChannel0() { + Flux publisher = setup.getRSocket().requestChannel(Flux.empty()); + + long count = publisher.count().block(); + + assertEquals(0, count); + } + + @Test(timeout = 10000) + public void testChannel1() { + Flux publisher = setup.getRSocket().requestChannel(Flux.just(testPayload(0))); + + long count = publisher.count().block(); + + assertEquals(1, count); + } + + @Test(timeout = 10000) + public void testChannel3() { + Flux publisher = + setup + .getRSocket() + .requestChannel(Flux.just(testPayload(0), testPayload(1), testPayload(2))); + + long count = publisher.count().block(); + + assertEquals(3, count); + } } diff --git a/rsocket-test/src/main/java/io/rsocket/test/TestRSocket.java b/rsocket-test/src/main/java/io/rsocket/test/TestRSocket.java index 4e89be889..5f75400e8 100644 --- a/rsocket-test/src/main/java/io/rsocket/test/TestRSocket.java +++ b/rsocket-test/src/main/java/io/rsocket/test/TestRSocket.java @@ -19,6 +19,7 @@ import io.rsocket.AbstractRSocket; import io.rsocket.Payload; import io.rsocket.util.PayloadImpl; +import org.reactivestreams.Publisher; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; @@ -43,4 +44,10 @@ public Mono metadataPush(Payload payload) { public Mono fireAndForget(Payload payload) { return Mono.empty(); } + + @Override + public Flux requestChannel(Publisher payloads) { + // TODO is defensive copy neccesary? + return Flux.from(payloads).map(p -> new PayloadImpl(p.getDataUtf8(), p.getMetadataUtf8())); + } }