1
+ import logging
1
2
from contextlib import AbstractAsyncContextManager
2
3
from datetime import timedelta
3
- from typing import Generic , TypeVar
4
+ from typing import Any , Callable , Generic , TypeVar
4
5
5
6
import anyio
6
7
import anyio .lowlevel
10
11
11
12
from mcp .shared .exceptions import McpError
12
13
from mcp .types import (
14
+ CancelledNotification ,
13
15
ClientNotification ,
14
16
ClientRequest ,
15
17
ClientResult ,
38
40
39
41
40
42
class RequestResponder (Generic [ReceiveRequestT , SendResultT ]):
43
+ """Handles responding to MCP requests and manages request lifecycle.
44
+
45
+ This class MUST be used as a context manager to ensure proper cleanup and
46
+ cancellation handling:
47
+
48
+ Example:
49
+ with request_responder as resp:
50
+ await resp.respond(result)
51
+
52
+ The context manager ensures:
53
+ 1. Proper cancellation scope setup and cleanup
54
+ 2. Request completion tracking
55
+ 3. Cleanup of in-flight requests
56
+ """
57
+
41
58
def __init__ (
42
59
self ,
43
60
request_id : RequestId ,
44
61
request_meta : RequestParams .Meta | None ,
45
62
request : ReceiveRequestT ,
46
63
session : "BaseSession" ,
64
+ on_complete : Callable [["RequestResponder[ReceiveRequestT, SendResultT]" ], Any ],
47
65
) -> None :
48
66
self .request_id = request_id
49
67
self .request_meta = request_meta
50
68
self .request = request
51
69
self ._session = session
52
- self ._responded = False
70
+ self ._completed = False
71
+ self ._cancel_scope = anyio .CancelScope ()
72
+ self ._on_complete = on_complete
73
+ self ._entered = False # Track if we're in a context manager
74
+
75
+ def __enter__ (self ) -> "RequestResponder[ReceiveRequestT, SendResultT]" :
76
+ """Enter the context manager, enabling request cancellation tracking."""
77
+ self ._entered = True
78
+ self ._cancel_scope = anyio .CancelScope ()
79
+ self ._cancel_scope .__enter__ ()
80
+ return self
81
+
82
+ def __exit__ (self , exc_type , exc_val , exc_tb ) -> None :
83
+ """Exit the context manager, performing cleanup and notifying completion."""
84
+ try :
85
+ if self ._completed :
86
+ self ._on_complete (self )
87
+ finally :
88
+ self ._entered = False
89
+ if not self ._cancel_scope :
90
+ raise RuntimeError ("No active cancel scope" )
91
+ self ._cancel_scope .__exit__ (exc_type , exc_val , exc_tb )
53
92
54
93
async def respond (self , response : SendResultT | ErrorData ) -> None :
55
- assert not self ._responded , "Request already responded to"
56
- self ._responded = True
94
+ """Send a response for this request.
95
+
96
+ Must be called within a context manager block.
97
+ Raises:
98
+ RuntimeError: If not used within a context manager
99
+ AssertionError: If request was already responded to
100
+ """
101
+ if not self ._entered :
102
+ raise RuntimeError ("RequestResponder must be used as a context manager" )
103
+ assert not self ._completed , "Request already responded to"
104
+
105
+ if not self .cancelled :
106
+ self ._completed = True
107
+
108
+ await self ._session ._send_response (
109
+ request_id = self .request_id , response = response
110
+ )
111
+
112
+ async def cancel (self ) -> None :
113
+ """Cancel this request and mark it as completed."""
114
+ if not self ._entered :
115
+ raise RuntimeError ("RequestResponder must be used as a context manager" )
116
+ if not self ._cancel_scope :
117
+ raise RuntimeError ("No active cancel scope" )
57
118
119
+ self ._cancel_scope .cancel ()
120
+ self ._completed = True # Mark as completed so it's removed from in_flight
121
+ # Send an error response to indicate cancellation
58
122
await self ._session ._send_response (
59
- request_id = self .request_id , response = response
123
+ request_id = self .request_id ,
124
+ response = ErrorData (code = 0 , message = "Request cancelled" , data = None ),
60
125
)
61
126
127
+ @property
128
+ def in_flight (self ) -> bool :
129
+ return not self ._completed and not self .cancelled
130
+
131
+ @property
132
+ def cancelled (self ) -> bool :
133
+ return self ._cancel_scope is not None and self ._cancel_scope .cancel_called
134
+
62
135
63
136
class BaseSession (
64
137
AbstractAsyncContextManager ,
@@ -82,6 +155,7 @@ class BaseSession(
82
155
RequestId , MemoryObjectSendStream [JSONRPCResponse | JSONRPCError ]
83
156
]
84
157
_request_id : int
158
+ _in_flight : dict [RequestId , RequestResponder [ReceiveRequestT , SendResultT ]]
85
159
86
160
def __init__ (
87
161
self ,
@@ -99,6 +173,7 @@ def __init__(
99
173
self ._receive_request_type = receive_request_type
100
174
self ._receive_notification_type = receive_notification_type
101
175
self ._read_timeout_seconds = read_timeout_seconds
176
+ self ._in_flight = {}
102
177
103
178
self ._incoming_message_stream_writer , self ._incoming_message_stream_reader = (
104
179
anyio .create_memory_object_stream [
@@ -219,27 +294,45 @@ async def _receive_loop(self) -> None:
219
294
by_alias = True , mode = "json" , exclude_none = True
220
295
)
221
296
)
297
+
222
298
responder = RequestResponder (
223
299
request_id = message .root .id ,
224
300
request_meta = validated_request .root .params .meta
225
301
if validated_request .root .params
226
302
else None ,
227
303
request = validated_request ,
228
304
session = self ,
305
+ on_complete = lambda r : self ._in_flight .pop (r .request_id , None ),
229
306
)
230
307
308
+ self ._in_flight [responder .request_id ] = responder
231
309
await self ._received_request (responder )
232
- if not responder ._responded :
310
+ if not responder ._completed :
233
311
await self ._incoming_message_stream_writer .send (responder )
312
+
234
313
elif isinstance (message .root , JSONRPCNotification ):
235
- notification = self ._receive_notification_type .model_validate (
236
- message .root .model_dump (
237
- by_alias = True , mode = "json" , exclude_none = True
314
+ try :
315
+ notification = self ._receive_notification_type .model_validate (
316
+ message .root .model_dump (
317
+ by_alias = True , mode = "json" , exclude_none = True
318
+ )
319
+ )
320
+ # Handle cancellation notifications
321
+ if isinstance (notification .root , CancelledNotification ):
322
+ cancelled_id = notification .root .params .requestId
323
+ if cancelled_id in self ._in_flight :
324
+ await self ._in_flight [cancelled_id ].cancel ()
325
+ else :
326
+ await self ._received_notification (notification )
327
+ await self ._incoming_message_stream_writer .send (
328
+ notification
329
+ )
330
+ except Exception as e :
331
+ # For other validation errors, log and continue
332
+ logging .warning (
333
+ f"Failed to validate notification: { e } . "
334
+ f"Message was: { message .root } "
238
335
)
239
- )
240
-
241
- await self ._received_notification (notification )
242
- await self ._incoming_message_stream_writer .send (notification )
243
336
else : # Response or error
244
337
stream = self ._response_streams .pop (message .root .id , None )
245
338
if stream :
0 commit comments