From 76a0b80c4c3e48f4b549e892e706e07100ae781f Mon Sep 17 00:00:00 2001 From: David Soria Parra Date: Mon, 11 Nov 2024 21:53:34 +0000 Subject: [PATCH] feat: add client capability checking to ServerSession Add methods to track and verify client capabilities during initialization. This includes storing client parameters from the initialize request and providing a check_client_capability method to verify if specific capabilities are supported by the connected client. --- src/mcp/server/session.py | 39 ++++++++++++++++++++++++++++++++++++++- 1 file changed, 38 insertions(+), 1 deletion(-) diff --git a/src/mcp/server/session.py b/src/mcp/server/session.py index dd6630784..3f13240bc 100644 --- a/src/mcp/server/session.py +++ b/src/mcp/server/session.py @@ -30,6 +30,7 @@ class ServerSession( ] ): _initialized: InitializationState = InitializationState.NotInitialized + _client_params: types.InitializeRequestParams | None = None def __init__( self, @@ -43,12 +44,47 @@ def __init__( self._initialization_state = InitializationState.NotInitialized self._init_options = init_options + @property + def client_params(self) -> types.InitializeRequestParams | None: + return self._client_params + + def check_client_capability(self, capability: types.ClientCapabilities) -> bool: + """Check if the client supports a specific capability.""" + if self._client_params is None: + return False + + # Get client capabilities from initialization params + client_caps = self._client_params.capabilities + + # Check each specified capability in the passed in capability object + if capability.roots is not None: + if client_caps.roots is None: + return False + if capability.roots.listChanged and not client_caps.roots.listChanged: + return False + + if capability.sampling is not None: + if client_caps.sampling is None: + return False + + if capability.experimental is not None: + if client_caps.experimental is None: + return False + # Check each experimental capability + for exp_key, exp_value in capability.experimental.items(): + if (exp_key not in client_caps.experimental or + client_caps.experimental[exp_key] != exp_value): + return False + + return True + async def _received_request( self, responder: RequestResponder[types.ClientRequest, types.ServerResult] ): match responder.request.root: - case types.InitializeRequest(): + case types.InitializeRequest(params=params): self._initialization_state = InitializationState.Initializing + self._client_params = params await responder.respond( types.ServerResult( types.InitializeResult( @@ -81,6 +117,7 @@ async def _received_notification( "Received notification before initialization was complete" ) + async def send_log_message( self, level: types.LoggingLevel, data: Any, logger: str | None = None ) -> None: