Skip to content

Commit b591452

Browse files
committed
Add support for endpoints with glob-pattern methods
1 parent e09c4ec commit b591452

File tree

2 files changed

+50
-6
lines changed

2 files changed

+50
-6
lines changed

fs2/src/jsonrpclib/fs2/FS2Channel.scala

Lines changed: 24 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ import jsonrpclib.internals.MessageDispatcher
1616
import jsonrpclib.internals._
1717

1818
import scala.util.Try
19+
import java.util.regex.Pattern
1920

2021
trait FS2Channel[F[_]] extends Channel[F] {
2122

@@ -52,7 +53,7 @@ object FS2Channel {
5253
): Stream[F, FS2Channel[F]] = {
5354
for {
5455
supervisor <- Stream.resource(Supervisor[F])
55-
ref <- Ref[F].of(State[F](Map.empty, Map.empty, Map.empty, 0)).toStream
56+
ref <- Ref[F].of(State[F](Map.empty, Map.empty, Map.empty, Vector.empty, 0)).toStream
5657
queue <- cats.effect.std.Queue.bounded[F, Payload](bufferSize).toStream
5758
impl = new Impl(queue, ref, supervisor, cancelTemplate)
5859

@@ -73,6 +74,7 @@ object FS2Channel {
7374
runningCalls: Map[CallId, Fiber[F, Throwable, Unit]],
7475
pendingCalls: Map[CallId, OutputMessage => F[Unit]],
7576
endpoints: Map[String, Endpoint[F]],
77+
globEndpoints: Vector[(Pattern, Endpoint[F])],
7678
counter: Long
7779
) {
7880
def nextCallId: (State[F], CallId) = (this.copy(counter = counter + 1), CallId.NumberId(counter))
@@ -82,11 +84,27 @@ object FS2Channel {
8284
val result = pendingCalls.get(callId)
8385
(this.copy(pendingCalls = pendingCalls.removed(callId)), result)
8486
}
85-
def mountEndpoint(endpoint: Endpoint[F]): Either[ConflictingMethodError, State[F]] =
86-
endpoints.get(endpoint.method) match {
87-
case None => Right(this.copy(endpoints = endpoints + (endpoint.method -> endpoint)))
88-
case Some(_) => Left(ConflictingMethodError(endpoint.method))
87+
def mountEndpoint(endpoint: Endpoint[F]): Either[ConflictingMethodError, State[F]] = {
88+
import endpoint.method
89+
if (method.contains("*")) {
90+
val parts = method
91+
.split("\\*", -1)
92+
.map { // Don't discard trailing empty string, if any.
93+
case "" => ""
94+
case str => Pattern.quote(str)
95+
}
96+
val glob = Pattern.compile(parts.mkString(".*"))
97+
Right(this.copy(globEndpoints = globEndpoints :+ (glob -> endpoint)))
98+
} else {
99+
endpoints.get(endpoint.method) match {
100+
case None => Right(this.copy(endpoints = endpoints + (endpoint.method -> endpoint)))
101+
case Some(_) => Left(ConflictingMethodError(endpoint.method))
102+
}
89103
}
104+
}
105+
def getEndpoint(method: String): Option[Endpoint[F]] = {
106+
endpoints.get(method).orElse(globEndpoints.find(_._1.matcher(method).matches()).map(_._2))
107+
}
90108
def removeEndpoint(method: String): State[F] =
91109
copy(endpoints = endpoints.removed(method))
92110

@@ -135,7 +153,7 @@ object FS2Channel {
135153
}
136154
}
137155
protected def reportError(params: Option[Payload], error: ProtocolError, method: String): F[Unit] = ???
138-
protected def getEndpoint(method: String): F[Option[Endpoint[F]]] = state.get.map(_.endpoints.get(method))
156+
protected def getEndpoint(method: String): F[Option[Endpoint[F]]] = state.get.map(_.getEndpoint(method))
139157
protected def sendMessage(message: Message): F[Unit] = queue.offer(Codec.encode(message))
140158

141159
protected def nextCallId(): F[CallId] = state.modify(_.nextCallId)

fs2/test/src/jsonrpclib/fs2/FS2ChannelSpec.scala

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,32 @@ object FS2ChannelSpec extends SimpleIOSuite {
5454
}
5555
}
5656

57+
testRes("Round trip (glob)") {
58+
val endpoint: Endpoint[IO] = Endpoint[IO]("**").simple((int: IntWrapper) => IO(IntWrapper(int.int + 1)))
59+
60+
for {
61+
clientSideChannel <- setup(endpoint)
62+
remoteFunction = clientSideChannel.simpleStub[IntWrapper, IntWrapper]("inc/test")
63+
result <- remoteFunction(IntWrapper(1)).toStream
64+
} yield {
65+
expect.same(result, IntWrapper(2))
66+
}
67+
}
68+
69+
testRes("Globs have lower priority than strict endpoints") {
70+
val endpoint: Endpoint[IO] = Endpoint[IO]("inc").simple((int: IntWrapper) => IO(IntWrapper(int.int + 1)))
71+
val globEndpoint: Endpoint[IO] =
72+
Endpoint[IO]("**").simple((_: IntWrapper) => IO.raiseError[IntWrapper](new Throwable("Boom")))
73+
74+
for {
75+
clientSideChannel <- setup(globEndpoint, endpoint)
76+
remoteFunction = clientSideChannel.simpleStub[IntWrapper, IntWrapper]("inc")
77+
result <- remoteFunction(IntWrapper(1)).toStream
78+
} yield {
79+
expect.same(result, IntWrapper(2))
80+
}
81+
}
82+
5783
testRes("Endpoint not mounted") {
5884

5985
for {

0 commit comments

Comments
 (0)