Skip to content

feat: add client capability checking to ServerSession #45

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Nov 11, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 38 additions & 1 deletion src/mcp/server/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ class ServerSession(
]
):
_initialized: InitializationState = InitializationState.NotInitialized
_client_params: types.InitializeRequestParams | None = None

def __init__(
self,
Expand All @@ -43,12 +44,47 @@ def __init__(
self._initialization_state = InitializationState.NotInitialized
self._init_options = init_options

@property
def client_params(self) -> types.InitializeRequestParams | None:
return self._client_params

def check_client_capability(self, capability: types.ClientCapabilities) -> bool:
"""Check if the client supports a specific capability."""
if self._client_params is None:
return False

# Get client capabilities from initialization params
client_caps = self._client_params.capabilities

# Check each specified capability in the passed in capability object
if capability.roots is not None:
if client_caps.roots is None:
return False
if capability.roots.listChanged and not client_caps.roots.listChanged:
return False

if capability.sampling is not None:
if client_caps.sampling is None:
return False

if capability.experimental is not None:
if client_caps.experimental is None:
return False
# Check each experimental capability
for exp_key, exp_value in capability.experimental.items():
if (exp_key not in client_caps.experimental or
client_caps.experimental[exp_key] != exp_value):
return False

return True

async def _received_request(
self, responder: RequestResponder[types.ClientRequest, types.ServerResult]
):
match responder.request.root:
case types.InitializeRequest():
case types.InitializeRequest(params=params):
self._initialization_state = InitializationState.Initializing
self._client_params = params
await responder.respond(
types.ServerResult(
types.InitializeResult(
Expand Down Expand Up @@ -81,6 +117,7 @@ async def _received_notification(
"Received notification before initialization was complete"
)


async def send_log_message(
self, level: types.LoggingLevel, data: Any, logger: str | None = None
) -> None:
Expand Down