Skip to content

Commit 0fdf7d4

Browse files
committed
WIP Protocol validation
1 parent 3ea255f commit 0fdf7d4

File tree

1 file changed

+44
-34
lines changed

1 file changed

+44
-34
lines changed

src/shared/protocol.ts

Lines changed: 44 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,23 @@
1+
import { ZodLiteral, ZodObject, z } from "zod";
12
import {
3+
ClientRequestSchema,
24
ErrorCode,
35
JSONRPCError,
46
JSONRPCNotification,
57
JSONRPCRequest,
68
JSONRPCResponse,
79
McpError,
810
Notification,
11+
type NotificationSchema,
912
PingRequestSchema,
1013
Progress,
1114
ProgressNotification,
1215
ProgressNotificationSchema,
1316
Request,
17+
type RequestSchema,
1418
Result,
19+
type ResultSchema,
20+
ServerRequestSchema,
1521
} from "../types.js";
1622
import { Transport } from "./transport.js";
1723

@@ -25,9 +31,12 @@ export type ProgressCallback = (progress: Progress) => void;
2531
* features like request/response linking, notifications, and progress.
2632
*/
2733
export class Protocol<
28-
ReceiveRequestT extends Request,
29-
ReceiveNotificationT extends Notification,
30-
ReceiveResultT extends Result,
34+
ReceiveRequestSchemaT =
35+
| typeof ClientRequestSchema
36+
| typeof ServerRequestSchema,
37+
ReceiveNotificationSchemaT extends typeof NotificationSchema &
38+
ZodObject<{ method: ZodLiteral<string> }>,
39+
ReceiveResultSchemaT extends typeof ResultSchema,
3140
SendRequestT extends Request,
3241
SendNotificationT extends Notification,
3342
SendResultT extends Result,
@@ -36,15 +45,15 @@ export class Protocol<
3645
private _requestMessageId = 0;
3746
private _requestHandlers: Map<
3847
string,
39-
(request: ReceiveRequestT) => Promise<SendResultT>
48+
(request: JSONRPCRequest) => Promise<SendResultT>
4049
> = new Map();
4150
private _notificationHandlers: Map<
4251
string,
43-
(notification: ReceiveNotificationT) => Promise<void>
52+
(notification: JSONRPCNotification) => Promise<void>
4453
> = new Map();
4554
private _responseHandlers: Map<
4655
number,
47-
(response: ReceiveResultT | Error) => void
56+
(response: JSONRPCResponse | Error) => void
4857
> = new Map();
4958
private _progressHandlers: Map<number, ProgressCallback> = new Map();
5059

@@ -65,25 +74,24 @@ export class Protocol<
6574
/**
6675
* A handler to invoke for any request types that do not have their own handler installed.
6776
*/
68-
fallbackRequestHandler?: (request: ReceiveRequestT) => Promise<SendResultT>;
77+
fallbackRequestHandler?: (
78+
request: z.infer<ReceiveRequestSchemaT>,
79+
) => Promise<SendResultT>;
6980

7081
/**
7182
* A handler to invoke for any notification types that do not have their own handler installed.
7283
*/
7384
fallbackNotificationHandler?: (
74-
notification: ReceiveNotificationT,
85+
notification: z.infer<ReceiveNotificationSchemaT>,
7586
) => Promise<void>;
7687

7788
constructor() {
78-
this.setNotificationHandler(
79-
ProgressNotificationSchema.shape.method.value,
80-
(notification) => {
81-
this._onprogress(notification as unknown as ProgressNotification);
82-
},
83-
);
89+
this.setNotificationHandler(ProgressNotificationSchema, (notification) => {
90+
this._onprogress(notification as unknown as ProgressNotification);
91+
});
8492

8593
this.setRequestHandler(
86-
PingRequestSchema.shape.method.value,
94+
PingRequestSchema,
8795
// Automatic pong by default.
8896
(_request) => ({}) as SendResultT,
8997
);
@@ -106,11 +114,11 @@ export class Protocol<
106114

107115
this._transport.onmessage = (message) => {
108116
if (!("method" in message)) {
109-
this._onresponse(message as JSONRPCResponse | JSONRPCError);
117+
this._onresponse(message);
110118
} else if ("id" in message) {
111-
this._onrequest(message as JSONRPCRequest);
119+
this._onrequest(message);
112120
} else {
113-
this._onnotification(message as JSONRPCNotification);
121+
this._onnotification(message);
114122
}
115123
};
116124
}
@@ -142,7 +150,7 @@ export class Protocol<
142150
return;
143151
}
144152

145-
handler(notification as unknown as ReceiveNotificationT).catch((error) =>
153+
handler(notification).catch((error) =>
146154
this._onerror(
147155
new Error(`Uncaught error in notification handler: ${error}`),
148156
),
@@ -171,7 +179,7 @@ export class Protocol<
171179
return;
172180
}
173181

174-
handler(request as unknown as ReceiveRequestT)
182+
handler(request)
175183
.then(
176184
(result) => {
177185
this._transport?.send({
@@ -228,7 +236,7 @@ export class Protocol<
228236
this._responseHandlers.delete(Number(messageId));
229237
this._progressHandlers.delete(Number(messageId));
230238
if ("result" in response) {
231-
handler(response.result as ReceiveResultT);
239+
handler(response);
232240
} else {
233241
const error = new McpError(
234242
response.error.code,
@@ -256,10 +264,11 @@ export class Protocol<
256264
* Do not use this method to emit notifications! Use notification() instead.
257265
*/
258266
// TODO: This could infer a better response type based on the method
259-
request(
267+
request<T extends ReceiveResultSchemaT>(
260268
request: SendRequestT,
269+
resultSchema: T,
261270
onprogress?: ProgressCallback,
262-
): Promise<ReceiveResultT> {
271+
): Promise<z.infer<T>> {
263272
return new Promise((resolve, reject) => {
264273
if (!this._transport) {
265274
reject(new Error("Not connected"));
@@ -314,13 +323,12 @@ export class Protocol<
314323
*
315324
* Note that this will replace any previous request handler for the same method.
316325
*/
317-
// TODO: This could infer a better request type based on the method.
318-
setRequestHandler(
319-
method: string,
320-
handler: (request: ReceiveRequestT) => SendResultT | Promise<SendResultT>,
326+
setRequestHandler<T extends ReceiveRequestSchemaT>(
327+
requestSchema: T,
328+
handler: (request: z.infer<T>) => SendResultT | Promise<SendResultT>,
321329
): void {
322-
this._requestHandlers.set(method, (request) =>
323-
Promise.resolve(handler(request)),
330+
this._requestHandlers.set(requestSchema.shape.method.value, (request) =>
331+
Promise.resolve(handler(requestSchema.parse(request))),
324332
);
325333
}
326334

@@ -337,12 +345,14 @@ export class Protocol<
337345
* Note that this will replace any previous notification handler for the same method.
338346
*/
339347
// TODO: This could infer a better notification type based on the method.
340-
setNotificationHandler<T extends ReceiveNotificationT>(
341-
method: string,
342-
handler: (notification: T) => void | Promise<void>,
348+
setNotificationHandler<T extends ReceiveNotificationSchemaT>(
349+
notificationSchema: T,
350+
handler: (notification: z.infer<T>) => void | Promise<void>,
343351
): void {
344-
this._notificationHandlers.set(method, (notification) =>
345-
Promise.resolve(handler(notification as T)),
352+
this._notificationHandlers.set(
353+
notificationSchema.shape.method.value,
354+
(notification) =>
355+
Promise.resolve(handler(notificationSchema.parse(notification))),
346356
);
347357
}
348358

0 commit comments

Comments
 (0)