|
13 | 13 | from mcp.shared.session import BaseSession, ProgressFnT, RequestResponder
|
14 | 14 | from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS
|
15 | 15 |
|
16 |
| - |
17 |
| -class ToolOutputValidator: |
18 |
| - async def validate( |
19 |
| - self, request: types.CallToolRequest, result: types.CallToolResult |
20 |
| - ) -> bool: |
21 |
| - raise RuntimeError("Not implemented") |
22 |
| - |
23 |
| - |
24 | 16 | DEFAULT_CLIENT_INFO = types.Implementation(name="mcp", version="0.1.0")
|
25 | 17 |
|
26 | 18 |
|
@@ -54,6 +46,12 @@ async def __call__(
|
54 | 46 | ) -> None: ...
|
55 | 47 |
|
56 | 48 |
|
| 49 | +class ToolOutputValidationFnT(Protocol): |
| 50 | + async def __call__( |
| 51 | + self, request: types.CallToolRequest, result: types.CallToolResult |
| 52 | + ) -> bool: ... |
| 53 | + |
| 54 | + |
57 | 55 | async def _default_message_handler(
|
58 | 56 | message: RequestResponder[types.ServerRequest, types.ClientResult]
|
59 | 57 | | types.ServerNotification
|
@@ -89,13 +87,13 @@ async def _default_logging_callback(
|
89 | 87 |
|
90 | 88 | ToolOutputValidatorProvider: TypeAlias = Callable[
|
91 | 89 | ...,
|
92 |
| - Awaitable[ToolOutputValidator], |
| 90 | + Awaitable[ToolOutputValidationFnT], |
93 | 91 | ]
|
94 | 92 |
|
95 | 93 |
|
96 | 94 | # this bag of spanners is required in order to
|
97 | 95 | # enable the client session to be parsed to the validator
|
98 |
| -async def _python_circularity_hell(arg: Any) -> ToolOutputValidator: |
| 96 | +async def _python_circularity_hell(arg: Any) -> ToolOutputValidationFnT: |
99 | 97 | # in any sane version of the universe this should never happen
|
100 | 98 | # of course in any sane programming language class circularity
|
101 | 99 | # dependencies shouldn't be this hard to manage
|
@@ -327,7 +325,7 @@ async def call_tool(
|
327 | 325 | )
|
328 | 326 |
|
329 | 327 | if validate_result:
|
330 |
| - valid = await self._tool_output_validator.validate(request, result) |
| 328 | + valid = await self._tool_output_validator(request, result) |
331 | 329 |
|
332 | 330 | if not valid:
|
333 | 331 | raise RuntimeError("Server responded with invalid result: " f"{result}")
|
@@ -451,15 +449,15 @@ async def _received_notification(
|
451 | 449 | pass
|
452 | 450 |
|
453 | 451 |
|
454 |
| -class SimpleCachingToolOutputValidator(ToolOutputValidator): |
| 452 | +class SimpleCachingToolOutputValidator(ToolOutputValidationFnT): |
455 | 453 | _schema_cache: dict[str, dict[str, Any] | bool]
|
456 | 454 |
|
457 | 455 | def __init__(self, session: ClientSession):
|
458 | 456 | self._session = session
|
459 | 457 | self._schema_cache = {}
|
460 | 458 | self._refresh_cache = True
|
461 | 459 |
|
462 |
| - async def validate( |
| 460 | + async def __call__( |
463 | 461 | self, request: types.CallToolRequest, result: types.CallToolResult
|
464 | 462 | ) -> bool:
|
465 | 463 | if result.isError:
|
@@ -508,7 +506,7 @@ async def _refresh_schema_cache(self):
|
508 | 506 |
|
509 | 507 | async def _escape_from_circular_python_hell(
|
510 | 508 | session: ClientSession,
|
511 |
| -) -> ToolOutputValidator: |
| 509 | +) -> ToolOutputValidationFnT: |
512 | 510 | return SimpleCachingToolOutputValidator(session)
|
513 | 511 |
|
514 | 512 |
|
|
0 commit comments