1
+ import { ZodLiteral , ZodObject , z } from "zod" ;
1
2
import {
3
+ ClientRequestSchema ,
2
4
ErrorCode ,
3
5
JSONRPCError ,
4
6
JSONRPCNotification ,
5
7
JSONRPCRequest ,
6
8
JSONRPCResponse ,
7
9
McpError ,
8
10
Notification ,
11
+ type NotificationSchema ,
9
12
PingRequestSchema ,
10
13
Progress ,
11
14
ProgressNotification ,
12
15
ProgressNotificationSchema ,
13
16
Request ,
17
+ type RequestSchema ,
14
18
Result ,
19
+ type ResultSchema ,
20
+ ServerRequestSchema ,
15
21
} from "../types.js" ;
16
22
import { Transport } from "./transport.js" ;
17
23
@@ -25,9 +31,12 @@ export type ProgressCallback = (progress: Progress) => void;
25
31
* features like request/response linking, notifications, and progress.
26
32
*/
27
33
export class Protocol <
28
- ReceiveRequestT extends Request ,
29
- ReceiveNotificationT extends Notification ,
30
- ReceiveResultT extends Result ,
34
+ ReceiveRequestSchemaT =
35
+ | typeof ClientRequestSchema
36
+ | typeof ServerRequestSchema ,
37
+ ReceiveNotificationSchemaT extends typeof NotificationSchema &
38
+ ZodObject < { method : ZodLiteral < string > } > ,
39
+ ReceiveResultSchemaT extends typeof ResultSchema ,
31
40
SendRequestT extends Request ,
32
41
SendNotificationT extends Notification ,
33
42
SendResultT extends Result ,
@@ -36,15 +45,15 @@ export class Protocol<
36
45
private _requestMessageId = 0 ;
37
46
private _requestHandlers : Map <
38
47
string ,
39
- ( request : ReceiveRequestT ) => Promise < SendResultT >
48
+ ( request : JSONRPCRequest ) => Promise < SendResultT >
40
49
> = new Map ( ) ;
41
50
private _notificationHandlers : Map <
42
51
string ,
43
- ( notification : ReceiveNotificationT ) => Promise < void >
52
+ ( notification : JSONRPCNotification ) => Promise < void >
44
53
> = new Map ( ) ;
45
54
private _responseHandlers : Map <
46
55
number ,
47
- ( response : ReceiveResultT | Error ) => void
56
+ ( response : JSONRPCResponse | Error ) => void
48
57
> = new Map ( ) ;
49
58
private _progressHandlers : Map < number , ProgressCallback > = new Map ( ) ;
50
59
@@ -65,25 +74,24 @@ export class Protocol<
65
74
/**
66
75
* A handler to invoke for any request types that do not have their own handler installed.
67
76
*/
68
- fallbackRequestHandler ?: ( request : ReceiveRequestT ) => Promise < SendResultT > ;
77
+ fallbackRequestHandler ?: (
78
+ request : z . infer < ReceiveRequestSchemaT > ,
79
+ ) => Promise < SendResultT > ;
69
80
70
81
/**
71
82
* A handler to invoke for any notification types that do not have their own handler installed.
72
83
*/
73
84
fallbackNotificationHandler ?: (
74
- notification : ReceiveNotificationT ,
85
+ notification : z . infer < ReceiveNotificationSchemaT > ,
75
86
) => Promise < void > ;
76
87
77
88
constructor ( ) {
78
- this . setNotificationHandler (
79
- ProgressNotificationSchema . shape . method . value ,
80
- ( notification ) => {
81
- this . _onprogress ( notification as unknown as ProgressNotification ) ;
82
- } ,
83
- ) ;
89
+ this . setNotificationHandler ( ProgressNotificationSchema , ( notification ) => {
90
+ this . _onprogress ( notification as unknown as ProgressNotification ) ;
91
+ } ) ;
84
92
85
93
this . setRequestHandler (
86
- PingRequestSchema . shape . method . value ,
94
+ PingRequestSchema ,
87
95
// Automatic pong by default.
88
96
( _request ) => ( { } ) as SendResultT ,
89
97
) ;
@@ -106,11 +114,11 @@ export class Protocol<
106
114
107
115
this . _transport . onmessage = ( message ) => {
108
116
if ( ! ( "method" in message ) ) {
109
- this . _onresponse ( message as JSONRPCResponse | JSONRPCError ) ;
117
+ this . _onresponse ( message ) ;
110
118
} else if ( "id" in message ) {
111
- this . _onrequest ( message as JSONRPCRequest ) ;
119
+ this . _onrequest ( message ) ;
112
120
} else {
113
- this . _onnotification ( message as JSONRPCNotification ) ;
121
+ this . _onnotification ( message ) ;
114
122
}
115
123
} ;
116
124
}
@@ -142,7 +150,7 @@ export class Protocol<
142
150
return ;
143
151
}
144
152
145
- handler ( notification as unknown as ReceiveNotificationT ) . catch ( ( error ) =>
153
+ handler ( notification ) . catch ( ( error ) =>
146
154
this . _onerror (
147
155
new Error ( `Uncaught error in notification handler: ${ error } ` ) ,
148
156
) ,
@@ -171,7 +179,7 @@ export class Protocol<
171
179
return ;
172
180
}
173
181
174
- handler ( request as unknown as ReceiveRequestT )
182
+ handler ( request )
175
183
. then (
176
184
( result ) => {
177
185
this . _transport ?. send ( {
@@ -228,7 +236,7 @@ export class Protocol<
228
236
this . _responseHandlers . delete ( Number ( messageId ) ) ;
229
237
this . _progressHandlers . delete ( Number ( messageId ) ) ;
230
238
if ( "result" in response ) {
231
- handler ( response . result as ReceiveResultT ) ;
239
+ handler ( response ) ;
232
240
} else {
233
241
const error = new McpError (
234
242
response . error . code ,
@@ -256,10 +264,11 @@ export class Protocol<
256
264
* Do not use this method to emit notifications! Use notification() instead.
257
265
*/
258
266
// TODO: This could infer a better response type based on the method
259
- request (
267
+ request < T extends ReceiveResultSchemaT > (
260
268
request : SendRequestT ,
269
+ resultSchema : T ,
261
270
onprogress ?: ProgressCallback ,
262
- ) : Promise < ReceiveResultT > {
271
+ ) : Promise < z . infer < T > > {
263
272
return new Promise ( ( resolve , reject ) => {
264
273
if ( ! this . _transport ) {
265
274
reject ( new Error ( "Not connected" ) ) ;
@@ -314,13 +323,12 @@ export class Protocol<
314
323
*
315
324
* Note that this will replace any previous request handler for the same method.
316
325
*/
317
- // TODO: This could infer a better request type based on the method.
318
- setRequestHandler (
319
- method : string ,
320
- handler : ( request : ReceiveRequestT ) => SendResultT | Promise < SendResultT > ,
326
+ setRequestHandler < T extends ReceiveRequestSchemaT > (
327
+ requestSchema : T ,
328
+ handler : ( request : z . infer < T > ) => SendResultT | Promise < SendResultT > ,
321
329
) : void {
322
- this . _requestHandlers . set ( method , ( request ) =>
323
- Promise . resolve ( handler ( request ) ) ,
330
+ this . _requestHandlers . set ( requestSchema . shape . method . value , ( request ) =>
331
+ Promise . resolve ( handler ( requestSchema . parse ( request ) ) ) ,
324
332
) ;
325
333
}
326
334
@@ -337,12 +345,14 @@ export class Protocol<
337
345
* Note that this will replace any previous notification handler for the same method.
338
346
*/
339
347
// TODO: This could infer a better notification type based on the method.
340
- setNotificationHandler < T extends ReceiveNotificationT > (
341
- method : string ,
342
- handler : ( notification : T ) => void | Promise < void > ,
348
+ setNotificationHandler < T extends ReceiveNotificationSchemaT > (
349
+ notificationSchema : T ,
350
+ handler : ( notification : z . infer < T > ) => void | Promise < void > ,
343
351
) : void {
344
- this . _notificationHandlers . set ( method , ( notification ) =>
345
- Promise . resolve ( handler ( notification as T ) ) ,
352
+ this . _notificationHandlers . set (
353
+ notificationSchema . shape . method . value ,
354
+ ( notification ) =>
355
+ Promise . resolve ( handler ( notificationSchema . parse ( notification ) ) ) ,
346
356
) ;
347
357
}
348
358
0 commit comments