Skip to content

Commit ad83eea

Browse files
committed
Tidy up to follow conventions for ToolOutputValidation as used in rest of session class
1 parent 43ebe80 commit ad83eea

File tree

1 file changed

+12
-14
lines changed

1 file changed

+12
-14
lines changed

src/mcp/client/session.py

Lines changed: 12 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -13,14 +13,6 @@
1313
from mcp.shared.session import BaseSession, ProgressFnT, RequestResponder
1414
from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS
1515

16-
17-
class ToolOutputValidator:
18-
async def validate(
19-
self, request: types.CallToolRequest, result: types.CallToolResult
20-
) -> bool:
21-
raise RuntimeError("Not implemented")
22-
23-
2416
DEFAULT_CLIENT_INFO = types.Implementation(name="mcp", version="0.1.0")
2517

2618

@@ -54,6 +46,12 @@ async def __call__(
5446
) -> None: ...
5547

5648

49+
class ToolOutputValidationFnT(Protocol):
50+
async def __call__(
51+
self, request: types.CallToolRequest, result: types.CallToolResult
52+
) -> bool: ...
53+
54+
5755
async def _default_message_handler(
5856
message: RequestResponder[types.ServerRequest, types.ClientResult]
5957
| types.ServerNotification
@@ -89,13 +87,13 @@ async def _default_logging_callback(
8987

9088
ToolOutputValidatorProvider: TypeAlias = Callable[
9189
...,
92-
Awaitable[ToolOutputValidator],
90+
Awaitable[ToolOutputValidationFnT],
9391
]
9492

9593

9694
# this bag of spanners is required in order to
9795
# enable the client session to be parsed to the validator
98-
async def _python_circularity_hell(arg: Any) -> ToolOutputValidator:
96+
async def _python_circularity_hell(arg: Any) -> ToolOutputValidationFnT:
9997
# in any sane version of the universe this should never happen
10098
# of course in any sane programming language class circularity
10199
# dependencies shouldn't be this hard to manage
@@ -327,7 +325,7 @@ async def call_tool(
327325
)
328326

329327
if validate_result:
330-
valid = await self._tool_output_validator.validate(request, result)
328+
valid = await self._tool_output_validator(request, result)
331329

332330
if not valid:
333331
raise RuntimeError("Server responded with invalid result: " f"{result}")
@@ -451,15 +449,15 @@ async def _received_notification(
451449
pass
452450

453451

454-
class SimpleCachingToolOutputValidator(ToolOutputValidator):
452+
class SimpleCachingToolOutputValidator(ToolOutputValidationFnT):
455453
_schema_cache: dict[str, dict[str, Any] | bool]
456454

457455
def __init__(self, session: ClientSession):
458456
self._session = session
459457
self._schema_cache = {}
460458
self._refresh_cache = True
461459

462-
async def validate(
460+
async def __call__(
463461
self, request: types.CallToolRequest, result: types.CallToolResult
464462
) -> bool:
465463
if result.isError:
@@ -508,7 +506,7 @@ async def _refresh_schema_cache(self):
508506

509507
async def _escape_from_circular_python_hell(
510508
session: ClientSession,
511-
) -> ToolOutputValidator:
509+
) -> ToolOutputValidationFnT:
512510
return SimpleCachingToolOutputValidator(session)
513511

514512

0 commit comments

Comments
 (0)