Skip to content

Commit 16b6e84

Browse files
authored
Merge pull request #47 from neandertech/globs
Revise endpoint and channel constructs to facilitate proxyfication
2 parents e412717 + 1108514 commit 16b6e84

File tree

8 files changed

+110
-44
lines changed

8 files changed

+110
-44
lines changed

core/src/jsonrpclib/Endpoint.scala

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,19 @@ sealed trait Endpoint[F[_]] {
66

77
object Endpoint {
88

9-
def apply[F[_]](method: String): PartiallyAppliedEndpoint[F] = new PartiallyAppliedEndpoint[F](method)
9+
type MethodPattern = String
10+
type Method = String
1011

11-
class PartiallyAppliedEndpoint[F[_]](method: String) {
12+
def apply[F[_]](method: Method): PartiallyAppliedEndpoint[F] = new PartiallyAppliedEndpoint[F](method)
13+
14+
class PartiallyAppliedEndpoint[F[_]](method: MethodPattern) {
1215
def apply[In, Err, Out](
1316
run: In => F[Either[Err, Out]]
17+
)(implicit inCodec: Codec[In], errCodec: ErrorCodec[Err], outCodec: Codec[Out]): Endpoint[F] =
18+
RequestResponseEndpoint(method, (_: InputMessage, in: In) => run(in), inCodec, errCodec, outCodec)
19+
20+
def full[In, Err, Out](
21+
run: (InputMessage, In) => F[Either[Err, Out]]
1422
)(implicit inCodec: Codec[In], errCodec: ErrorCodec[Err], outCodec: Codec[Out]): Endpoint[F] =
1523
RequestResponseEndpoint(method, run, inCodec, errCodec, outCodec)
1624

@@ -25,16 +33,22 @@ object Endpoint {
2533
)
2634

2735
def notification[In](run: In => F[Unit])(implicit inCodec: Codec[In]): Endpoint[F] =
36+
NotificationEndpoint(method, (_: InputMessage, in: In) => run(in), inCodec)
37+
38+
def notificationFull[In](run: (InputMessage, In) => F[Unit])(implicit inCodec: Codec[In]): Endpoint[F] =
2839
NotificationEndpoint(method, run, inCodec)
2940

3041
}
3142

32-
final case class NotificationEndpoint[F[_], In](method: String, run: In => F[Unit], inCodec: Codec[In])
33-
extends Endpoint[F]
43+
final case class NotificationEndpoint[F[_], In](
44+
method: MethodPattern,
45+
run: (InputMessage, In) => F[Unit],
46+
inCodec: Codec[In]
47+
) extends Endpoint[F]
3448

3549
final case class RequestResponseEndpoint[F[_], In, Err, Out](
36-
method: String,
37-
run: In => F[Either[Err, Out]],
50+
method: Method,
51+
run: (InputMessage, In) => F[Either[Err, Out]],
3852
inCodec: Codec[In],
3953
errCodec: ErrorCodec[Err],
4054
outCodec: Codec[Out]

core/src/jsonrpclib/internals/Message.scala renamed to core/src/jsonrpclib/Message.scala

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,34 +1,32 @@
11
package jsonrpclib
2-
package internals
32

43
import com.github.plokhotnyuk.jsoniter_scala.core.JsonReader
54
import com.github.plokhotnyuk.jsoniter_scala.core.JsonValueCodec
65
import com.github.plokhotnyuk.jsoniter_scala.core.JsonWriter
76

87
sealed trait Message { def maybeCallId: Option[CallId] }
9-
private[jsonrpclib] sealed trait InputMessage extends Message { def method: String }
10-
private[jsonrpclib] sealed trait OutputMessage extends Message {
8+
sealed trait InputMessage extends Message { def method: String }
9+
sealed trait OutputMessage extends Message {
1110
def callId: CallId; final override def maybeCallId: Option[CallId] = Some(callId)
1211
}
1312

14-
private[jsonrpclib] object InputMessage {
13+
object InputMessage {
1514
case class RequestMessage(method: String, callId: CallId, params: Option[Payload]) extends InputMessage {
1615
def maybeCallId: Option[CallId] = Some(callId)
1716
}
1817
case class NotificationMessage(method: String, params: Option[Payload]) extends InputMessage {
1918
def maybeCallId: Option[CallId] = None
2019
}
2120
}
22-
23-
private[jsonrpclib] object OutputMessage {
21+
object OutputMessage {
2422
def errorFrom(callId: CallId, protocolError: ProtocolError): OutputMessage =
2523
ErrorMessage(callId, ErrorPayload(protocolError.code, protocolError.getMessage(), None))
2624

2725
case class ErrorMessage(callId: CallId, payload: ErrorPayload) extends OutputMessage
2826
case class ResponseMessage(callId: CallId, data: Payload) extends OutputMessage
2927
}
3028

31-
private[jsonrpclib] object Message {
29+
object Message {
3230

3331
implicit val messageJsonValueCodecs: JsonValueCodec[Message] = new JsonValueCodec[Message] {
3432
val rawMessageCodec = implicitly[JsonValueCodec[internals.RawMessage]]

core/src/jsonrpclib/internals/MessageDispatcher.scala

Lines changed: 8 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,10 @@
11
package jsonrpclib
22
package internals
33

4-
import jsonrpclib.internals._
54
import jsonrpclib.Endpoint.NotificationEndpoint
65
import jsonrpclib.Endpoint.RequestResponseEndpoint
7-
import jsonrpclib.internals.OutputMessage.ErrorMessage
8-
import jsonrpclib.internals.OutputMessage.ResponseMessage
6+
import jsonrpclib.OutputMessage.ErrorMessage
7+
import jsonrpclib.OutputMessage.ResponseMessage
98
import scala.util.Try
109

1110
private[jsonrpclib] abstract class MessageDispatcher[F[_]](implicit F: Monadic[F]) extends Channel.MonadicChannel[F] {
@@ -41,8 +40,8 @@ private[jsonrpclib] abstract class MessageDispatcher[F[_]](implicit F: Monadic[F
4140
}
4241
}
4342

44-
protected[jsonrpclib] def handleReceivedPayload(payload: Payload): F[Unit] = {
45-
Codec.decode[Message](Some(payload)).map {
43+
protected[jsonrpclib] def handleReceivedMessage(message: Message): F[Unit] = {
44+
message match {
4645
case im: InputMessage =>
4746
doFlatMap(getEndpoint(im.method)) {
4847
case Some(ep) => background(im.maybeCallId, executeInputMessage(im, ep))
@@ -61,29 +60,25 @@ private[jsonrpclib] abstract class MessageDispatcher[F[_]](implicit F: Monadic[F
6160
case Some(pendingCall) => pendingCall(om)
6261
case None => doPure(()) // TODO do something
6362
}
64-
} match {
65-
case Left(error) =>
66-
sendProtocolError(error)
67-
case Right(dispatch) => dispatch
6863
}
6964
}
7065

71-
private def sendProtocolError(callId: CallId, pError: ProtocolError): F[Unit] =
66+
protected def sendProtocolError(callId: CallId, pError: ProtocolError): F[Unit] =
7267
sendMessage(OutputMessage.errorFrom(callId, pError))
73-
private def sendProtocolError(pError: ProtocolError): F[Unit] =
68+
protected def sendProtocolError(pError: ProtocolError): F[Unit] =
7469
sendProtocolError(CallId.NullId, pError)
7570

7671
private def executeInputMessage(input: InputMessage, endpoint: Endpoint[F]): F[Unit] = {
7772
(input, endpoint) match {
7873
case (InputMessage.NotificationMessage(_, params), ep: NotificationEndpoint[F, in]) =>
7974
ep.inCodec.decode(params) match {
80-
case Right(value) => ep.run(value)
75+
case Right(value) => ep.run(input, value)
8176
case Left(value) => reportError(params, value, ep.method)
8277
}
8378
case (InputMessage.RequestMessage(_, callId, params), ep: RequestResponseEndpoint[F, in, err, out]) =>
8479
ep.inCodec.decode(params) match {
8580
case Right(value) =>
86-
doFlatMap(ep.run(value)) {
81+
doFlatMap(ep.run(input, value)) {
8782
case Right(data) =>
8883
val responseData = ep.outCodec.encode(data)
8984
sendMessage(OutputMessage.ResponseMessage(callId, responseData))

examples/client/src/examples/client/ClientMain.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,8 +41,8 @@ object ClientMain extends IOApp.Simple {
4141
// Creating a channel that will be used to communicate to the server
4242
fs2Channel <- FS2Channel[IO](cancelTemplate = cancelEndpoint.some)
4343
_ <- Stream(())
44-
.concurrently(fs2Channel.output.through(lsp.encodePayloads).through(rp.stdin))
45-
.concurrently(rp.stdout.through(lsp.decodePayloads).through(fs2Channel.input))
44+
.concurrently(fs2Channel.output.through(lsp.encodeMessages).through(rp.stdin))
45+
.concurrently(rp.stdout.through(lsp.decodeMessages).through(fs2Channel.inputOrBounce))
4646
.concurrently(rp.stderr.through(fs2.io.stderr[IO]))
4747
// Creating a `IntWrapper => IO[IntWrapper]` stub that can call the server
4848
increment = fs2Channel.simpleStub[IntWrapper, IntWrapper]("increment")

examples/server/src/examples/server/ServerMain.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,8 +34,8 @@ object ServerMain extends IOApp.Simple {
3434
.flatMap(channel =>
3535
fs2.Stream
3636
.eval(IO.never) // running the server forever
37-
.concurrently(stdin[IO](512).through(lsp.decodePayloads).through(channel.input))
38-
.concurrently(channel.output.through(lsp.encodePayloads).through(stdout[IO]))
37+
.concurrently(stdin[IO](512).through(lsp.decodeMessages).through(channel.inputOrBounce))
38+
.concurrently(channel.output.through(lsp.encodeMessages).through(stdout[IO]))
3939
)
4040
.compile
4141
.drain

fs2/src/jsonrpclib/fs2/FS2Channel.scala

Lines changed: 36 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -13,14 +13,15 @@ import cats.effect.std.Supervisor
1313
import cats.syntax.all._
1414
import cats.effect.syntax.all._
1515
import jsonrpclib.internals.MessageDispatcher
16-
import jsonrpclib.internals._
1716

1817
import scala.util.Try
18+
import java.util.regex.Pattern
1919

2020
trait FS2Channel[F[_]] extends Channel[F] {
2121

22-
def input: Pipe[F, Payload, Unit]
23-
def output: Stream[F, Payload]
22+
def input: Pipe[F, Message, Unit]
23+
def inputOrBounce: Pipe[F, Either[ProtocolError, Message], Unit]
24+
def output: Stream[F, Message]
2425

2526
def withEndpoint(endpoint: Endpoint[F])(implicit F: Functor[F]): Resource[F, FS2Channel[F]] =
2627
Resource.make(mountEndpoint(endpoint))(_ => unmountEndpoint(endpoint.method)).map(_ => this)
@@ -52,8 +53,8 @@ object FS2Channel {
5253
): Stream[F, FS2Channel[F]] = {
5354
for {
5455
supervisor <- Stream.resource(Supervisor[F])
55-
ref <- Ref[F].of(State[F](Map.empty, Map.empty, Map.empty, 0)).toStream
56-
queue <- cats.effect.std.Queue.bounded[F, Payload](bufferSize).toStream
56+
ref <- Ref[F].of(State[F](Map.empty, Map.empty, Map.empty, Vector.empty, 0)).toStream
57+
queue <- cats.effect.std.Queue.bounded[F, Message](bufferSize).toStream
5758
impl = new Impl(queue, ref, supervisor, cancelTemplate)
5859

5960
// Creating a bespoke endpoint to receive cancelation requests
@@ -73,6 +74,7 @@ object FS2Channel {
7374
runningCalls: Map[CallId, Fiber[F, Throwable, Unit]],
7475
pendingCalls: Map[CallId, OutputMessage => F[Unit]],
7576
endpoints: Map[String, Endpoint[F]],
77+
globEndpoints: Vector[(Pattern, Endpoint[F])],
7678
counter: Long
7779
) {
7880
def nextCallId: (State[F], CallId) = (this.copy(counter = counter + 1), CallId.NumberId(counter))
@@ -82,11 +84,27 @@ object FS2Channel {
8284
val result = pendingCalls.get(callId)
8385
(this.copy(pendingCalls = pendingCalls.removed(callId)), result)
8486
}
85-
def mountEndpoint(endpoint: Endpoint[F]): Either[ConflictingMethodError, State[F]] =
86-
endpoints.get(endpoint.method) match {
87-
case None => Right(this.copy(endpoints = endpoints + (endpoint.method -> endpoint)))
88-
case Some(_) => Left(ConflictingMethodError(endpoint.method))
87+
def mountEndpoint(endpoint: Endpoint[F]): Either[ConflictingMethodError, State[F]] = {
88+
import endpoint.method
89+
if (method.contains("*")) {
90+
val parts = method
91+
.split("\\*", -1)
92+
.map { // Don't discard trailing empty string, if any.
93+
case "" => ""
94+
case str => Pattern.quote(str)
95+
}
96+
val glob = Pattern.compile(parts.mkString(".*"))
97+
Right(this.copy(globEndpoints = globEndpoints :+ (glob -> endpoint)))
98+
} else {
99+
endpoints.get(endpoint.method) match {
100+
case None => Right(this.copy(endpoints = endpoints + (endpoint.method -> endpoint)))
101+
case Some(_) => Left(ConflictingMethodError(endpoint.method))
102+
}
89103
}
104+
}
105+
def getEndpoint(method: String): Option[Endpoint[F]] = {
106+
endpoints.get(method).orElse(globEndpoints.find(_._1.matcher(method).matches()).map(_._2))
107+
}
90108
def removeEndpoint(method: String): State[F] =
91109
copy(endpoints = endpoints.removed(method))
92110

@@ -98,16 +116,20 @@ object FS2Channel {
98116
}
99117

100118
private class Impl[F[_]](
101-
private val queue: cats.effect.std.Queue[F, Payload],
119+
private val queue: cats.effect.std.Queue[F, Message],
102120
private val state: Ref[F, FS2Channel.State[F]],
103121
supervisor: Supervisor[F],
104122
maybeCancelTemplate: Option[CancelTemplate]
105123
)(implicit F: Concurrent[F])
106124
extends MessageDispatcher[F]
107125
with FS2Channel[F] {
108126

109-
def output: Stream[F, Payload] = Stream.fromQueueUnterminated(queue)
110-
def input: Pipe[F, Payload, Unit] = _.evalMap(handleReceivedPayload)
127+
def output: Stream[F, Message] = Stream.fromQueueUnterminated(queue)
128+
def inputOrBounce: Pipe[F, Either[ProtocolError, Message], Unit] = _.evalMap {
129+
case Left(error) => sendProtocolError(error)
130+
case Right(message) => handleReceivedMessage(message)
131+
}
132+
def input: Pipe[F, Message, Unit] = _.evalMap(handleReceivedMessage)
111133

112134
def mountEndpoint(endpoint: Endpoint[F]): F[Unit] = state
113135
.modify(s =>
@@ -135,8 +157,8 @@ object FS2Channel {
135157
}
136158
}
137159
protected def reportError(params: Option[Payload], error: ProtocolError, method: String): F[Unit] = ???
138-
protected def getEndpoint(method: String): F[Option[Endpoint[F]]] = state.get.map(_.endpoints.get(method))
139-
protected def sendMessage(message: Message): F[Unit] = queue.offer(Codec.encode(message))
160+
protected def getEndpoint(method: String): F[Option[Endpoint[F]]] = state.get.map(_.getEndpoint(method))
161+
protected def sendMessage(message: Message): F[Unit] = queue.offer(message)
140162

141163
protected def nextCallId(): F[CallId] = state.modify(_.nextCallId)
142164
protected def createPromise[A](callId: CallId): F[(Try[A] => F[Unit], () => F[A])] = Deferred[F, Try[A]].map {

fs2/src/jsonrpclib/fs2/lsp.scala

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,15 +6,26 @@ import fs2.Chunk
66
import fs2.Stream
77
import fs2.Pipe
88
import jsonrpclib.Payload
9+
import jsonrpclib.Codec
910

1011
import java.nio.charset.Charset
1112
import java.nio.charset.StandardCharsets
13+
import jsonrpclib.Message
14+
import jsonrpclib.ProtocolError
1215

1316
object lsp {
1417

18+
def encodeMessages[F[_]]: Pipe[F, Message, Byte] =
19+
(_: Stream[F, Message]).map(Codec.encode(_)).through(encodePayloads)
20+
1521
def encodePayloads[F[_]]: Pipe[F, Payload, Byte] =
1622
(_: Stream[F, Payload]).map(writeChunk).flatMap(Stream.chunk(_))
1723

24+
def decodeMessages[F[_]: MonadThrow]: Pipe[F, Byte, Either[ProtocolError, Message]] =
25+
(_: Stream[F, Byte]).through(decodePayloads).map { payload =>
26+
Codec.decode[Message](Some(payload))
27+
}
28+
1829
/** Split a stream of bytes into payloads by extracting each frame based on information contained in the headers.
1930
*
2031
* See https://microsoft.github.io/language-server-protocol/specifications/lsp/3.17/specification/#contentPart

fs2/test/src/jsonrpclib/fs2/FS2ChannelSpec.scala

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,32 @@ object FS2ChannelSpec extends SimpleIOSuite {
5454
}
5555
}
5656

57+
testRes("Round trip (glob)") {
58+
val endpoint: Endpoint[IO] = Endpoint[IO]("**").simple((int: IntWrapper) => IO(IntWrapper(int.int + 1)))
59+
60+
for {
61+
clientSideChannel <- setup(endpoint)
62+
remoteFunction = clientSideChannel.simpleStub[IntWrapper, IntWrapper]("inc/test")
63+
result <- remoteFunction(IntWrapper(1)).toStream
64+
} yield {
65+
expect.same(result, IntWrapper(2))
66+
}
67+
}
68+
69+
testRes("Globs have lower priority than strict endpoints") {
70+
val endpoint: Endpoint[IO] = Endpoint[IO]("inc").simple((int: IntWrapper) => IO(IntWrapper(int.int + 1)))
71+
val globEndpoint: Endpoint[IO] =
72+
Endpoint[IO]("**").simple((_: IntWrapper) => IO.raiseError[IntWrapper](new Throwable("Boom")))
73+
74+
for {
75+
clientSideChannel <- setup(globEndpoint, endpoint)
76+
remoteFunction = clientSideChannel.simpleStub[IntWrapper, IntWrapper]("inc")
77+
result <- remoteFunction(IntWrapper(1)).toStream
78+
} yield {
79+
expect.same(result, IntWrapper(2))
80+
}
81+
}
82+
5783
testRes("Endpoint not mounted") {
5884

5985
for {

0 commit comments

Comments
 (0)