Skip to content

Commit 1108514

Browse files
committed
Changed the FS2Channel to have input/output streams based on Message
Also, changed the Endpoint interfaces to carry over the InputMessage as parameters of their run functions. This should facilitate the proxification of servers, as the implementor of the proxy endpoint will be able to shove the input message onto the underlying's server's message queue.
1 parent 88b109a commit 1108514

File tree

7 files changed

+50
-39
lines changed

7 files changed

+50
-39
lines changed

core/src/jsonrpclib/Endpoint.scala

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,10 @@ object Endpoint {
1515
def apply[In, Err, Out](
1616
run: In => F[Either[Err, Out]]
1717
)(implicit inCodec: Codec[In], errCodec: ErrorCodec[Err], outCodec: Codec[Out]): Endpoint[F] =
18-
RequestResponseEndpoint(method, (_: Method, in: In) => run(in), inCodec, errCodec, outCodec)
18+
RequestResponseEndpoint(method, (_: InputMessage, in: In) => run(in), inCodec, errCodec, outCodec)
1919

2020
def full[In, Err, Out](
21-
run: (Method, In) => F[Either[Err, Out]]
21+
run: (InputMessage, In) => F[Either[Err, Out]]
2222
)(implicit inCodec: Codec[In], errCodec: ErrorCodec[Err], outCodec: Codec[Out]): Endpoint[F] =
2323
RequestResponseEndpoint(method, run, inCodec, errCodec, outCodec)
2424

@@ -33,19 +33,22 @@ object Endpoint {
3333
)
3434

3535
def notification[In](run: In => F[Unit])(implicit inCodec: Codec[In]): Endpoint[F] =
36-
NotificationEndpoint(method, (_: Method, in: In) => run(in), inCodec)
36+
NotificationEndpoint(method, (_: InputMessage, in: In) => run(in), inCodec)
3737

38-
def notificationFull[In](run: (Method, In) => F[Unit])(implicit inCodec: Codec[In]): Endpoint[F] =
38+
def notificationFull[In](run: (InputMessage, In) => F[Unit])(implicit inCodec: Codec[In]): Endpoint[F] =
3939
NotificationEndpoint(method, run, inCodec)
4040

4141
}
4242

43-
final case class NotificationEndpoint[F[_], In](method: Method, run: (Method, In) => F[Unit], inCodec: Codec[In])
44-
extends Endpoint[F]
43+
final case class NotificationEndpoint[F[_], In](
44+
method: MethodPattern,
45+
run: (InputMessage, In) => F[Unit],
46+
inCodec: Codec[In]
47+
) extends Endpoint[F]
4548

4649
final case class RequestResponseEndpoint[F[_], In, Err, Out](
4750
method: Method,
48-
run: (Method, In) => F[Either[Err, Out]],
51+
run: (InputMessage, In) => F[Either[Err, Out]],
4952
inCodec: Codec[In],
5053
errCodec: ErrorCodec[Err],
5154
outCodec: Codec[Out]

core/src/jsonrpclib/internals/Message.scala renamed to core/src/jsonrpclib/Message.scala

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,34 +1,32 @@
11
package jsonrpclib
2-
package internals
32

43
import com.github.plokhotnyuk.jsoniter_scala.core.JsonReader
54
import com.github.plokhotnyuk.jsoniter_scala.core.JsonValueCodec
65
import com.github.plokhotnyuk.jsoniter_scala.core.JsonWriter
76

87
sealed trait Message { def maybeCallId: Option[CallId] }
9-
private[jsonrpclib] sealed trait InputMessage extends Message { def method: String }
10-
private[jsonrpclib] sealed trait OutputMessage extends Message {
8+
sealed trait InputMessage extends Message { def method: String }
9+
sealed trait OutputMessage extends Message {
1110
def callId: CallId; final override def maybeCallId: Option[CallId] = Some(callId)
1211
}
1312

14-
private[jsonrpclib] object InputMessage {
13+
object InputMessage {
1514
case class RequestMessage(method: String, callId: CallId, params: Option[Payload]) extends InputMessage {
1615
def maybeCallId: Option[CallId] = Some(callId)
1716
}
1817
case class NotificationMessage(method: String, params: Option[Payload]) extends InputMessage {
1918
def maybeCallId: Option[CallId] = None
2019
}
2120
}
22-
23-
private[jsonrpclib] object OutputMessage {
21+
object OutputMessage {
2422
def errorFrom(callId: CallId, protocolError: ProtocolError): OutputMessage =
2523
ErrorMessage(callId, ErrorPayload(protocolError.code, protocolError.getMessage(), None))
2624

2725
case class ErrorMessage(callId: CallId, payload: ErrorPayload) extends OutputMessage
2826
case class ResponseMessage(callId: CallId, data: Payload) extends OutputMessage
2927
}
3028

31-
private[jsonrpclib] object Message {
29+
object Message {
3230

3331
implicit val messageJsonValueCodecs: JsonValueCodec[Message] = new JsonValueCodec[Message] {
3432
val rawMessageCodec = implicitly[JsonValueCodec[internals.RawMessage]]

core/src/jsonrpclib/internals/MessageDispatcher.scala

Lines changed: 8 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,10 @@
11
package jsonrpclib
22
package internals
33

4-
import jsonrpclib.internals._
54
import jsonrpclib.Endpoint.NotificationEndpoint
65
import jsonrpclib.Endpoint.RequestResponseEndpoint
7-
import jsonrpclib.internals.OutputMessage.ErrorMessage
8-
import jsonrpclib.internals.OutputMessage.ResponseMessage
6+
import jsonrpclib.OutputMessage.ErrorMessage
7+
import jsonrpclib.OutputMessage.ResponseMessage
98
import scala.util.Try
109

1110
private[jsonrpclib] abstract class MessageDispatcher[F[_]](implicit F: Monadic[F]) extends Channel.MonadicChannel[F] {
@@ -41,8 +40,8 @@ private[jsonrpclib] abstract class MessageDispatcher[F[_]](implicit F: Monadic[F
4140
}
4241
}
4342

44-
protected[jsonrpclib] def handleReceivedPayload(payload: Payload): F[Unit] = {
45-
Codec.decode[Message](Some(payload)).map {
43+
protected[jsonrpclib] def handleReceivedMessage(message: Message): F[Unit] = {
44+
message match {
4645
case im: InputMessage =>
4746
doFlatMap(getEndpoint(im.method)) {
4847
case Some(ep) => background(im.maybeCallId, executeInputMessage(im, ep))
@@ -61,29 +60,25 @@ private[jsonrpclib] abstract class MessageDispatcher[F[_]](implicit F: Monadic[F
6160
case Some(pendingCall) => pendingCall(om)
6261
case None => doPure(()) // TODO do something
6362
}
64-
} match {
65-
case Left(error) =>
66-
sendProtocolError(error)
67-
case Right(dispatch) => dispatch
6863
}
6964
}
7065

71-
private def sendProtocolError(callId: CallId, pError: ProtocolError): F[Unit] =
66+
protected def sendProtocolError(callId: CallId, pError: ProtocolError): F[Unit] =
7267
sendMessage(OutputMessage.errorFrom(callId, pError))
73-
private def sendProtocolError(pError: ProtocolError): F[Unit] =
68+
protected def sendProtocolError(pError: ProtocolError): F[Unit] =
7469
sendProtocolError(CallId.NullId, pError)
7570

7671
private def executeInputMessage(input: InputMessage, endpoint: Endpoint[F]): F[Unit] = {
7772
(input, endpoint) match {
7873
case (InputMessage.NotificationMessage(_, params), ep: NotificationEndpoint[F, in]) =>
7974
ep.inCodec.decode(params) match {
80-
case Right(value) => ep.run(input.method, value)
75+
case Right(value) => ep.run(input, value)
8176
case Left(value) => reportError(params, value, ep.method)
8277
}
8378
case (InputMessage.RequestMessage(_, callId, params), ep: RequestResponseEndpoint[F, in, err, out]) =>
8479
ep.inCodec.decode(params) match {
8580
case Right(value) =>
86-
doFlatMap(ep.run(input.method, value)) {
81+
doFlatMap(ep.run(input, value)) {
8782
case Right(data) =>
8883
val responseData = ep.outCodec.encode(data)
8984
sendMessage(OutputMessage.ResponseMessage(callId, responseData))

examples/client/src/examples/client/ClientMain.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,8 +41,8 @@ object ClientMain extends IOApp.Simple {
4141
// Creating a channel that will be used to communicate to the server
4242
fs2Channel <- FS2Channel[IO](cancelTemplate = cancelEndpoint.some)
4343
_ <- Stream(())
44-
.concurrently(fs2Channel.output.through(lsp.encodePayloads).through(rp.stdin))
45-
.concurrently(rp.stdout.through(lsp.decodePayloads).through(fs2Channel.input))
44+
.concurrently(fs2Channel.output.through(lsp.encodeMessages).through(rp.stdin))
45+
.concurrently(rp.stdout.through(lsp.decodeMessages).through(fs2Channel.inputOrBounce))
4646
.concurrently(rp.stderr.through(fs2.io.stderr[IO]))
4747
// Creating a `IntWrapper => IO[IntWrapper]` stub that can call the server
4848
increment = fs2Channel.simpleStub[IntWrapper, IntWrapper]("increment")

examples/server/src/examples/server/ServerMain.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,8 +34,8 @@ object ServerMain extends IOApp.Simple {
3434
.flatMap(channel =>
3535
fs2.Stream
3636
.eval(IO.never) // running the server forever
37-
.concurrently(stdin[IO](512).through(lsp.decodePayloads).through(channel.input))
38-
.concurrently(channel.output.through(lsp.encodePayloads).through(stdout[IO]))
37+
.concurrently(stdin[IO](512).through(lsp.decodeMessages).through(channel.inputOrBounce))
38+
.concurrently(channel.output.through(lsp.encodeMessages).through(stdout[IO]))
3939
)
4040
.compile
4141
.drain

fs2/src/jsonrpclib/fs2/FS2Channel.scala

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -13,15 +13,15 @@ import cats.effect.std.Supervisor
1313
import cats.syntax.all._
1414
import cats.effect.syntax.all._
1515
import jsonrpclib.internals.MessageDispatcher
16-
import jsonrpclib.internals._
1716

1817
import scala.util.Try
1918
import java.util.regex.Pattern
2019

2120
trait FS2Channel[F[_]] extends Channel[F] {
2221

23-
def input: Pipe[F, Payload, Unit]
24-
def output: Stream[F, Payload]
22+
def input: Pipe[F, Message, Unit]
23+
def inputOrBounce: Pipe[F, Either[ProtocolError, Message], Unit]
24+
def output: Stream[F, Message]
2525

2626
def withEndpoint(endpoint: Endpoint[F])(implicit F: Functor[F]): Resource[F, FS2Channel[F]] =
2727
Resource.make(mountEndpoint(endpoint))(_ => unmountEndpoint(endpoint.method)).map(_ => this)
@@ -54,7 +54,7 @@ object FS2Channel {
5454
for {
5555
supervisor <- Stream.resource(Supervisor[F])
5656
ref <- Ref[F].of(State[F](Map.empty, Map.empty, Map.empty, Vector.empty, 0)).toStream
57-
queue <- cats.effect.std.Queue.bounded[F, Payload](bufferSize).toStream
57+
queue <- cats.effect.std.Queue.bounded[F, Message](bufferSize).toStream
5858
impl = new Impl(queue, ref, supervisor, cancelTemplate)
5959

6060
// Creating a bespoke endpoint to receive cancelation requests
@@ -116,16 +116,20 @@ object FS2Channel {
116116
}
117117

118118
private class Impl[F[_]](
119-
private val queue: cats.effect.std.Queue[F, Payload],
119+
private val queue: cats.effect.std.Queue[F, Message],
120120
private val state: Ref[F, FS2Channel.State[F]],
121121
supervisor: Supervisor[F],
122122
maybeCancelTemplate: Option[CancelTemplate]
123123
)(implicit F: Concurrent[F])
124124
extends MessageDispatcher[F]
125125
with FS2Channel[F] {
126126

127-
def output: Stream[F, Payload] = Stream.fromQueueUnterminated(queue)
128-
def input: Pipe[F, Payload, Unit] = _.evalMap(handleReceivedPayload)
127+
def output: Stream[F, Message] = Stream.fromQueueUnterminated(queue)
128+
def inputOrBounce: Pipe[F, Either[ProtocolError, Message], Unit] = _.evalMap {
129+
case Left(error) => sendProtocolError(error)
130+
case Right(message) => handleReceivedMessage(message)
131+
}
132+
def input: Pipe[F, Message, Unit] = _.evalMap(handleReceivedMessage)
129133

130134
def mountEndpoint(endpoint: Endpoint[F]): F[Unit] = state
131135
.modify(s =>
@@ -154,7 +158,7 @@ object FS2Channel {
154158
}
155159
protected def reportError(params: Option[Payload], error: ProtocolError, method: String): F[Unit] = ???
156160
protected def getEndpoint(method: String): F[Option[Endpoint[F]]] = state.get.map(_.getEndpoint(method))
157-
protected def sendMessage(message: Message): F[Unit] = queue.offer(Codec.encode(message))
161+
protected def sendMessage(message: Message): F[Unit] = queue.offer(message)
158162

159163
protected def nextCallId(): F[CallId] = state.modify(_.nextCallId)
160164
protected def createPromise[A](callId: CallId): F[(Try[A] => F[Unit], () => F[A])] = Deferred[F, Try[A]].map {

fs2/src/jsonrpclib/fs2/lsp.scala

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,15 +6,26 @@ import fs2.Chunk
66
import fs2.Stream
77
import fs2.Pipe
88
import jsonrpclib.Payload
9+
import jsonrpclib.Codec
910

1011
import java.nio.charset.Charset
1112
import java.nio.charset.StandardCharsets
13+
import jsonrpclib.Message
14+
import jsonrpclib.ProtocolError
1215

1316
object lsp {
1417

18+
def encodeMessages[F[_]]: Pipe[F, Message, Byte] =
19+
(_: Stream[F, Message]).map(Codec.encode(_)).through(encodePayloads)
20+
1521
def encodePayloads[F[_]]: Pipe[F, Payload, Byte] =
1622
(_: Stream[F, Payload]).map(writeChunk).flatMap(Stream.chunk(_))
1723

24+
def decodeMessages[F[_]: MonadThrow]: Pipe[F, Byte, Either[ProtocolError, Message]] =
25+
(_: Stream[F, Byte]).through(decodePayloads).map { payload =>
26+
Codec.decode[Message](Some(payload))
27+
}
28+
1829
/** Split a stream of bytes into payloads by extracting each frame based on information contained in the headers.
1930
*
2031
* See https://microsoft.github.io/language-server-protocol/specifications/lsp/3.17/specification/#contentPart

0 commit comments

Comments
 (0)