diff --git a/src/mcp/client/session_group.py b/src/mcp/client/session_group.py index a430533b3..a77dc7a1e 100644 --- a/src/mcp/client/session_group.py +++ b/src/mcp/client/session_group.py @@ -154,7 +154,6 @@ async def __aexit__( for exit_stack in self._session_exit_stacks.values(): tg.start_soon(exit_stack.aclose) - @property def sessions(self) -> list[mcp.ClientSession]: """Returns the list of sessions being managed.""" diff --git a/src/mcp/shared/session.py b/src/mcp/shared/session.py index 90b4eb27c..04ae98751 100644 --- a/src/mcp/shared/session.py +++ b/src/mcp/shared/session.py @@ -351,26 +351,34 @@ async def _receive_loop(self) -> None: if isinstance(message, Exception): await self._handle_incoming(message) elif isinstance(message.message.root, JSONRPCRequest): - validated_request = self._receive_request_type.model_validate( - message.message.root.model_dump( - by_alias=True, mode="json", exclude_none=True + try: + validated_request = self._receive_request_type.model_validate( + message.message.root.model_dump( + by_alias=True, mode="json", exclude_none=True + ) + ) + responder = RequestResponder( + request_id=message.message.root.id, + request_meta=validated_request.root.params.meta + if validated_request.root.params + else None, + request=validated_request, + session=self, + on_complete=lambda r: self._in_flight.pop( + r.request_id, None + ), ) - ) - responder = RequestResponder( - request_id=message.message.root.id, - request_meta=validated_request.root.params.meta - if validated_request.root.params - else None, - request=validated_request, - session=self, - on_complete=lambda r: self._in_flight.pop(r.request_id, None), - ) - self._in_flight[responder.request_id] = responder - await self._received_request(responder) + self._in_flight[responder.request_id] = responder + await self._received_request(responder) - if not responder._completed: # type: ignore[reportPrivateUsage] - await self._handle_incoming(responder) + if not responder._completed: # type: ignore[reportPrivateUsage] + await self._handle_incoming(responder) + except Exception as e: + logging.warning( + f"Failed to validate request: {e}. " + f"Message was: {message.message.root}" + ) elif isinstance(message.message.root, JSONRPCNotification): try: diff --git a/tests/shared/test_session_handles_error.py b/tests/shared/test_session_handles_error.py new file mode 100644 index 000000000..0f4a96fba --- /dev/null +++ b/tests/shared/test_session_handles_error.py @@ -0,0 +1,42 @@ +from collections.abc import AsyncGenerator + +import pytest +from anyio.streams.memory import MemoryObjectSendStream + +from mcp.shared.memory import create_client_server_memory_streams +from mcp.shared.message import SessionMessage +from mcp.shared.session import BaseSession +from mcp.types import ClientNotification, ClientRequest, JSONRPCMessage, JSONRPCRequest + + +@pytest.fixture +async def client_write() -> ( + AsyncGenerator[MemoryObjectSendStream[SessionMessage], None] +): + """A stream that allows to write to a running session.""" + async with create_client_server_memory_streams() as ( + (_, client_write), + (server_read, server_write), + ): + async with BaseSession( + read_stream=server_read, + write_stream=server_write, + receive_request_type=ClientRequest, + receive_notification_type=ClientNotification, + ) as _: + yield client_write + + +@pytest.mark.anyio +async def test_session_does_not_raise_error_with_bad_input( + client_write: MemoryObjectSendStream[SessionMessage], +): + # Given a running session + + # When the client sends a bad request to the session + request = JSONRPCRequest(jsonrpc="2.0", id=1, method="bad_method", params=None) + message = SessionMessage(message=JSONRPCMessage(root=request)) + await client_write.send(message) + + # Then the session can still be talked to + await client_write.send(message)