1
+ from collections .abc import Awaitable , Callable
1
2
from datetime import timedelta
2
- from typing import Any , Protocol
3
+ from typing import Any , Protocol , TypeAlias
3
4
4
5
import anyio .lowlevel
5
6
from anyio .streams .memory import MemoryObjectReceiveStream , MemoryObjectSendStream
7
+ from jsonschema import ValidationError , validate
6
8
from pydantic import AnyUrl , TypeAdapter
7
9
8
10
import mcp .types as types
11
13
from mcp .shared .session import BaseSession , ProgressFnT , RequestResponder
12
14
from mcp .shared .version import SUPPORTED_PROTOCOL_VERSIONS
13
15
16
+
17
+ class ToolOutputValidator :
18
+ async def validate (
19
+ self , request : types .CallToolRequest , result : types .CallToolResult
20
+ ) -> bool :
21
+ raise RuntimeError ("Not implemented" )
22
+
23
+
14
24
DEFAULT_CLIENT_INFO = types .Implementation (name = "mcp" , version = "0.1.0" )
15
25
16
26
@@ -77,6 +87,25 @@ async def _default_logging_callback(
77
87
pass
78
88
79
89
90
+ ToolOutputValidatorProvider : TypeAlias = Callable [
91
+ ...,
92
+ Awaitable [ToolOutputValidator ],
93
+ ]
94
+
95
+
96
+ # this bag of spanners is required in order to
97
+ # enable the client session to be parsed to the validator
98
+ async def _python_circularity_hell (arg : Any ) -> ToolOutputValidator :
99
+ # in any sane version of the universe this should never happen
100
+ # of course in any sane programming language class circularity
101
+ # dependencies shouldn't be this hard to manage
102
+ raise RuntimeError (
103
+ "Help I'm stuck in python circularity hell, please send biscuits"
104
+ )
105
+
106
+
107
+ _default_tool_output_validator : ToolOutputValidatorProvider = _python_circularity_hell
108
+
80
109
ClientResponse : TypeAdapter [types .ClientResult | types .ErrorData ] = TypeAdapter (
81
110
types .ClientResult | types .ErrorData
82
111
)
@@ -101,6 +130,7 @@ def __init__(
101
130
logging_callback : LoggingFnT | None = None ,
102
131
message_handler : MessageHandlerFnT | None = None ,
103
132
client_info : types .Implementation | None = None ,
133
+ tool_output_validator_provider : ToolOutputValidatorProvider | None = None ,
104
134
) -> None :
105
135
super ().__init__ (
106
136
read_stream ,
@@ -114,6 +144,7 @@ def __init__(
114
144
self ._list_roots_callback = list_roots_callback or _default_list_roots_callback
115
145
self ._logging_callback = logging_callback or _default_logging_callback
116
146
self ._message_handler = message_handler or _default_message_handler
147
+ self ._tool_output_validator_provider = tool_output_validator_provider
117
148
118
149
async def initialize (self ) -> types .InitializeResult :
119
150
sampling = types .SamplingCapability ()
@@ -154,6 +185,11 @@ async def initialize(self) -> types.InitializeResult:
154
185
)
155
186
)
156
187
188
+ tool_output_validator_provider = (
189
+ self ._tool_output_validator_provider or _default_tool_output_validator
190
+ )
191
+ self ._tool_output_validator = await tool_output_validator_provider (self )
192
+
157
193
return result
158
194
159
195
async def send_ping (self ) -> types .EmptyResult :
@@ -271,24 +307,33 @@ async def call_tool(
271
307
arguments : dict [str , Any ] | None = None ,
272
308
read_timeout_seconds : timedelta | None = None ,
273
309
progress_callback : ProgressFnT | None = None ,
310
+ validate_result : bool = True ,
274
311
) -> types .CallToolResult :
275
312
"""Send a tools/call request with optional progress callback support."""
276
313
277
- return await self .send_request (
278
- types .ClientRequest (
279
- types .CallToolRequest (
280
- method = "tools/call" ,
281
- params = types .CallToolRequestParams (
282
- name = name ,
283
- arguments = arguments ,
284
- ),
285
- )
314
+ request = types .CallToolRequest (
315
+ method = "tools/call" ,
316
+ params = types .CallToolRequestParams (
317
+ name = name ,
318
+ arguments = arguments ,
286
319
),
320
+ )
321
+
322
+ result = await self .send_request (
323
+ types .ClientRequest (request ),
287
324
types .CallToolResult ,
288
325
request_read_timeout_seconds = read_timeout_seconds ,
289
326
progress_callback = progress_callback ,
290
327
)
291
328
329
+ if validate_result :
330
+ valid = await self ._tool_output_validator .validate (request , result )
331
+
332
+ if not valid :
333
+ raise RuntimeError ("Server responded with invalid result: " f"{ result } " )
334
+ # not validating or is valid
335
+ return result
336
+
292
337
async def list_prompts (self , cursor : str | None = None ) -> types .ListPromptsResult :
293
338
"""Send a prompts/list request."""
294
339
return await self .send_request (
@@ -404,3 +449,67 @@ async def _received_notification(
404
449
await self ._logging_callback (params )
405
450
case _:
406
451
pass
452
+
453
+
454
+ class SimpleCachingToolOutputValidator (ToolOutputValidator ):
455
+ _schema_cache : dict [str , dict [str , Any ] | bool ]
456
+
457
+ def __init__ (self , session : ClientSession ):
458
+ self ._session = session
459
+ self ._schema_cache = {}
460
+ self ._refresh_cache = True
461
+
462
+ async def validate (
463
+ self , request : types .CallToolRequest , result : types .CallToolResult
464
+ ) -> bool :
465
+ if result .isError :
466
+ # allow errors to be propagated
467
+ return True
468
+ else :
469
+ if self ._refresh_cache :
470
+ await self ._refresh_schema_cache ()
471
+
472
+ schema = self ._schema_cache .get (request .params .name )
473
+
474
+ if schema is None :
475
+ raise RuntimeError (f"Unknown tool { request .params .name } " )
476
+ elif schema is False :
477
+ # no schema
478
+ # TODO add logging
479
+ return result .structuredContent is None
480
+ else :
481
+ try :
482
+ # TODO opportunity to build jsonschema.protocol.Validator
483
+ # and reuse rather than build every time
484
+ validate (result .structuredContent , schema )
485
+ return True
486
+ except ValidationError :
487
+ # TODO log this
488
+ return False
489
+
490
+ async def _refresh_schema_cache (self ):
491
+ cursor = None
492
+ first = True
493
+ while first or cursor is not None :
494
+ first = False
495
+ tools_result = await self ._session .list_tools (cursor )
496
+ for tool in tools_result .tools :
497
+ # store a flag to be able to later distinguish between
498
+ # no schema for tool and unknown tool which can't be verified
499
+ schema_or_flag = (
500
+ False if tool .outputSchema is None else tool .outputSchema
501
+ )
502
+ self ._schema_cache [tool .name ] = schema_or_flag
503
+ cursor = tools_result .nextCursor
504
+ continue
505
+
506
+ self ._refresh_cache = False
507
+
508
+
509
+ async def _escape_from_circular_python_hell (
510
+ session : ClientSession ,
511
+ ) -> ToolOutputValidator :
512
+ return SimpleCachingToolOutputValidator (session )
513
+
514
+
515
+ _default_tool_output_validator = _escape_from_circular_python_hell
0 commit comments