Skip to content

Commit 43ebe80

Browse files
committed
add schema checking on client side via jsonschema
1 parent 982f6b0 commit 43ebe80

File tree

5 files changed

+935
-679
lines changed

5 files changed

+935
-679
lines changed

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ dependencies = [
3131
"sse-starlette>=1.6.1",
3232
"pydantic-settings>=2.5.2",
3333
"uvicorn>=0.23.1; sys_platform != 'emscripten'",
34+
"jsonschema==4.23.0",
3435
]
3536

3637
[project.optional-dependencies]

src/mcp/client/session.py

Lines changed: 119 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
1+
from collections.abc import Awaitable, Callable
12
from datetime import timedelta
2-
from typing import Any, Protocol
3+
from typing import Any, Protocol, TypeAlias
34

45
import anyio.lowlevel
56
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
7+
from jsonschema import ValidationError, validate
68
from pydantic import AnyUrl, TypeAdapter
79

810
import mcp.types as types
@@ -11,6 +13,14 @@
1113
from mcp.shared.session import BaseSession, ProgressFnT, RequestResponder
1214
from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS
1315

16+
17+
class ToolOutputValidator:
18+
async def validate(
19+
self, request: types.CallToolRequest, result: types.CallToolResult
20+
) -> bool:
21+
raise RuntimeError("Not implemented")
22+
23+
1424
DEFAULT_CLIENT_INFO = types.Implementation(name="mcp", version="0.1.0")
1525

1626

@@ -77,6 +87,25 @@ async def _default_logging_callback(
7787
pass
7888

7989

90+
ToolOutputValidatorProvider: TypeAlias = Callable[
91+
...,
92+
Awaitable[ToolOutputValidator],
93+
]
94+
95+
96+
# this bag of spanners is required in order to
97+
# enable the client session to be parsed to the validator
98+
async def _python_circularity_hell(arg: Any) -> ToolOutputValidator:
99+
# in any sane version of the universe this should never happen
100+
# of course in any sane programming language class circularity
101+
# dependencies shouldn't be this hard to manage
102+
raise RuntimeError(
103+
"Help I'm stuck in python circularity hell, please send biscuits"
104+
)
105+
106+
107+
_default_tool_output_validator: ToolOutputValidatorProvider = _python_circularity_hell
108+
80109
ClientResponse: TypeAdapter[types.ClientResult | types.ErrorData] = TypeAdapter(
81110
types.ClientResult | types.ErrorData
82111
)
@@ -101,6 +130,7 @@ def __init__(
101130
logging_callback: LoggingFnT | None = None,
102131
message_handler: MessageHandlerFnT | None = None,
103132
client_info: types.Implementation | None = None,
133+
tool_output_validator_provider: ToolOutputValidatorProvider | None = None,
104134
) -> None:
105135
super().__init__(
106136
read_stream,
@@ -114,6 +144,7 @@ def __init__(
114144
self._list_roots_callback = list_roots_callback or _default_list_roots_callback
115145
self._logging_callback = logging_callback or _default_logging_callback
116146
self._message_handler = message_handler or _default_message_handler
147+
self._tool_output_validator_provider = tool_output_validator_provider
117148

118149
async def initialize(self) -> types.InitializeResult:
119150
sampling = types.SamplingCapability()
@@ -154,6 +185,11 @@ async def initialize(self) -> types.InitializeResult:
154185
)
155186
)
156187

188+
tool_output_validator_provider = (
189+
self._tool_output_validator_provider or _default_tool_output_validator
190+
)
191+
self._tool_output_validator = await tool_output_validator_provider(self)
192+
157193
return result
158194

159195
async def send_ping(self) -> types.EmptyResult:
@@ -271,24 +307,33 @@ async def call_tool(
271307
arguments: dict[str, Any] | None = None,
272308
read_timeout_seconds: timedelta | None = None,
273309
progress_callback: ProgressFnT | None = None,
310+
validate_result: bool = True,
274311
) -> types.CallToolResult:
275312
"""Send a tools/call request with optional progress callback support."""
276313

277-
return await self.send_request(
278-
types.ClientRequest(
279-
types.CallToolRequest(
280-
method="tools/call",
281-
params=types.CallToolRequestParams(
282-
name=name,
283-
arguments=arguments,
284-
),
285-
)
314+
request = types.CallToolRequest(
315+
method="tools/call",
316+
params=types.CallToolRequestParams(
317+
name=name,
318+
arguments=arguments,
286319
),
320+
)
321+
322+
result = await self.send_request(
323+
types.ClientRequest(request),
287324
types.CallToolResult,
288325
request_read_timeout_seconds=read_timeout_seconds,
289326
progress_callback=progress_callback,
290327
)
291328

329+
if validate_result:
330+
valid = await self._tool_output_validator.validate(request, result)
331+
332+
if not valid:
333+
raise RuntimeError("Server responded with invalid result: " f"{result}")
334+
# not validating or is valid
335+
return result
336+
292337
async def list_prompts(self, cursor: str | None = None) -> types.ListPromptsResult:
293338
"""Send a prompts/list request."""
294339
return await self.send_request(
@@ -404,3 +449,67 @@ async def _received_notification(
404449
await self._logging_callback(params)
405450
case _:
406451
pass
452+
453+
454+
class SimpleCachingToolOutputValidator(ToolOutputValidator):
455+
_schema_cache: dict[str, dict[str, Any] | bool]
456+
457+
def __init__(self, session: ClientSession):
458+
self._session = session
459+
self._schema_cache = {}
460+
self._refresh_cache = True
461+
462+
async def validate(
463+
self, request: types.CallToolRequest, result: types.CallToolResult
464+
) -> bool:
465+
if result.isError:
466+
# allow errors to be propagated
467+
return True
468+
else:
469+
if self._refresh_cache:
470+
await self._refresh_schema_cache()
471+
472+
schema = self._schema_cache.get(request.params.name)
473+
474+
if schema is None:
475+
raise RuntimeError(f"Unknown tool {request.params.name}")
476+
elif schema is False:
477+
# no schema
478+
# TODO add logging
479+
return result.structuredContent is None
480+
else:
481+
try:
482+
# TODO opportunity to build jsonschema.protocol.Validator
483+
# and reuse rather than build every time
484+
validate(result.structuredContent, schema)
485+
return True
486+
except ValidationError:
487+
# TODO log this
488+
return False
489+
490+
async def _refresh_schema_cache(self):
491+
cursor = None
492+
first = True
493+
while first or cursor is not None:
494+
first = False
495+
tools_result = await self._session.list_tools(cursor)
496+
for tool in tools_result.tools:
497+
# store a flag to be able to later distinguish between
498+
# no schema for tool and unknown tool which can't be verified
499+
schema_or_flag = (
500+
False if tool.outputSchema is None else tool.outputSchema
501+
)
502+
self._schema_cache[tool.name] = schema_or_flag
503+
cursor = tools_result.nextCursor
504+
continue
505+
506+
self._refresh_cache = False
507+
508+
509+
async def _escape_from_circular_python_hell(
510+
session: ClientSession,
511+
) -> ToolOutputValidator:
512+
return SimpleCachingToolOutputValidator(session)
513+
514+
515+
_default_tool_output_validator = _escape_from_circular_python_hell

src/mcp/server/lowlevel/server.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -487,6 +487,7 @@ def handle_error(e: Exception):
487487
)
488488

489489
if schema_func is None:
490+
490491
async def handler(req: types.CallToolRequest):
491492
try:
492493
result = await func(
@@ -496,6 +497,7 @@ async def handler(req: types.CallToolRequest):
496497
except Exception as e:
497498
return handle_error(e)
498499
else:
500+
499501
async def handler(req: types.CallToolRequest):
500502
try:
501503
result = await func(

tests/issues/test_88_random_error.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
EmbeddedResource,
1616
ImageContent,
1717
TextContent,
18+
Tool,
1819
)
1920

2021

@@ -55,6 +56,10 @@ async def slow_tool(
5556
return [TextContent(type="text", text=f"fast {request_count}")]
5657
return [TextContent(type="text", text=f"unknown {request_count}")]
5758

59+
@server.list_tools()
60+
async def list_tools() -> list[Tool]:
61+
return [Tool(name="fast", inputSchema={}), Tool(name="slow", inputSchema={})]
62+
5863
async def server_handler(
5964
read_stream,
6065
write_stream,

0 commit comments

Comments
 (0)