@@ -30,6 +30,7 @@ class ServerSession(
30
30
]
31
31
):
32
32
_initialized : InitializationState = InitializationState .NotInitialized
33
+ _client_params : types .InitializeRequestParams | None = None
33
34
34
35
def __init__ (
35
36
self ,
@@ -43,12 +44,47 @@ def __init__(
43
44
self ._initialization_state = InitializationState .NotInitialized
44
45
self ._init_options = init_options
45
46
47
+ @property
48
+ def client_params (self ) -> types .InitializeRequestParams | None :
49
+ return self ._client_params
50
+
51
+ def check_client_capability (self , capability : types .ClientCapabilities ) -> bool :
52
+ """Check if the client supports a specific capability."""
53
+ if self ._client_params is None :
54
+ return False
55
+
56
+ # Get client capabilities from initialization params
57
+ client_caps = self ._client_params .capabilities
58
+
59
+ # Check each specified capability in the passed in capability object
60
+ if capability .roots is not None :
61
+ if client_caps .roots is None :
62
+ return False
63
+ if capability .roots .listChanged and not client_caps .roots .listChanged :
64
+ return False
65
+
66
+ if capability .sampling is not None :
67
+ if client_caps .sampling is None :
68
+ return False
69
+
70
+ if capability .experimental is not None :
71
+ if client_caps .experimental is None :
72
+ return False
73
+ # Check each experimental capability
74
+ for exp_key , exp_value in capability .experimental .items ():
75
+ if (exp_key not in client_caps .experimental or
76
+ client_caps .experimental [exp_key ] != exp_value ):
77
+ return False
78
+
79
+ return True
80
+
46
81
async def _received_request (
47
82
self , responder : RequestResponder [types .ClientRequest , types .ServerResult ]
48
83
):
49
84
match responder .request .root :
50
- case types .InitializeRequest ():
85
+ case types .InitializeRequest (params = params ):
51
86
self ._initialization_state = InitializationState .Initializing
87
+ self ._client_params = params
52
88
await responder .respond (
53
89
types .ServerResult (
54
90
types .InitializeResult (
@@ -81,6 +117,7 @@ async def _received_notification(
81
117
"Received notification before initialization was complete"
82
118
)
83
119
120
+
84
121
async def send_log_message (
85
122
self , level : types .LoggingLevel , data : Any , logger : str | None = None
86
123
) -> None :
0 commit comments