Skip to content

Commit bb3894a

Browse files
authored
Merge pull request #45 from modelcontextprotocol/davidsp/client-capabilities
feat: add client capability checking to ServerSession
2 parents df33a9b + 76a0b80 commit bb3894a

File tree

1 file changed

+38
-1
lines changed

1 file changed

+38
-1
lines changed

src/mcp/server/session.py

Lines changed: 38 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ class ServerSession(
3030
]
3131
):
3232
_initialized: InitializationState = InitializationState.NotInitialized
33+
_client_params: types.InitializeRequestParams | None = None
3334

3435
def __init__(
3536
self,
@@ -43,12 +44,47 @@ def __init__(
4344
self._initialization_state = InitializationState.NotInitialized
4445
self._init_options = init_options
4546

47+
@property
48+
def client_params(self) -> types.InitializeRequestParams | None:
49+
return self._client_params
50+
51+
def check_client_capability(self, capability: types.ClientCapabilities) -> bool:
52+
"""Check if the client supports a specific capability."""
53+
if self._client_params is None:
54+
return False
55+
56+
# Get client capabilities from initialization params
57+
client_caps = self._client_params.capabilities
58+
59+
# Check each specified capability in the passed in capability object
60+
if capability.roots is not None:
61+
if client_caps.roots is None:
62+
return False
63+
if capability.roots.listChanged and not client_caps.roots.listChanged:
64+
return False
65+
66+
if capability.sampling is not None:
67+
if client_caps.sampling is None:
68+
return False
69+
70+
if capability.experimental is not None:
71+
if client_caps.experimental is None:
72+
return False
73+
# Check each experimental capability
74+
for exp_key, exp_value in capability.experimental.items():
75+
if (exp_key not in client_caps.experimental or
76+
client_caps.experimental[exp_key] != exp_value):
77+
return False
78+
79+
return True
80+
4681
async def _received_request(
4782
self, responder: RequestResponder[types.ClientRequest, types.ServerResult]
4883
):
4984
match responder.request.root:
50-
case types.InitializeRequest():
85+
case types.InitializeRequest(params=params):
5186
self._initialization_state = InitializationState.Initializing
87+
self._client_params = params
5288
await responder.respond(
5389
types.ServerResult(
5490
types.InitializeResult(
@@ -81,6 +117,7 @@ async def _received_notification(
81117
"Received notification before initialization was complete"
82118
)
83119

120+
84121
async def send_log_message(
85122
self, level: types.LoggingLevel, data: Any, logger: str | None = None
86123
) -> None:

0 commit comments

Comments
 (0)