7
7
import httpx
8
8
from anyio .abc import TaskStatus
9
9
from anyio .streams .memory import MemoryObjectReceiveStream , MemoryObjectSendStream
10
+ from exceptiongroup import ExceptionGroup , catch
10
11
from httpx_sse import aconnect_sse
11
12
12
13
import mcp .types as types
@@ -18,6 +19,14 @@ def remove_request_params(url: str) -> str:
18
19
return urljoin (url , urlparse (url ).path )
19
20
20
21
22
+ def handle_exception (exc : Exception ) -> str :
23
+ """Handle ExceptionGroup and Exceptions for Client transport for SSE"""
24
+ if isinstance (exc , ExceptionGroup ):
25
+ messages = "; " .join (str (e ) for e in exc .exceptions )
26
+ raise Exception (f"TaskGroup failed with: { messages } " ) from None
27
+ else :
28
+ raise Exception (f"TaskGroup failed with: { exc } " ) from None
29
+
21
30
@asynccontextmanager
22
31
async def sse_client (
23
32
url : str ,
@@ -40,115 +49,115 @@ async def sse_client(
40
49
read_stream_writer , read_stream = anyio .create_memory_object_stream (0 )
41
50
write_stream , write_stream_reader = anyio .create_memory_object_stream (0 )
42
51
43
- errors : list [Exception ] = []
44
-
45
- async with anyio .create_task_group () as tg :
46
- try :
47
- logger .info (f"Connecting to SSE endpoint: { remove_request_params (url )} " )
48
- async with httpx .AsyncClient (headers = headers ) as client :
49
- async with aconnect_sse (
50
- client ,
51
- "GET" ,
52
- url ,
53
- timeout = httpx .Timeout (timeout , read = sse_read_timeout ),
54
- ) as event_source :
55
- event_source .response .raise_for_status ()
56
- logger .debug ("SSE connection established" )
57
-
58
- async def sse_reader (
59
- task_status : TaskStatus [str ] = anyio .TASK_STATUS_IGNORED ,
60
- ):
61
- try :
62
- async for sse in event_source .aiter_sse ():
63
- logger .debug (f"Received SSE event: { sse .event } " )
64
- match sse .event :
65
- case "endpoint" :
66
- endpoint_url = urljoin (url , sse .data )
67
- logger .info (
68
- f"Received endpoint URL: { endpoint_url } "
69
- )
70
-
71
- url_parsed = urlparse (url )
72
- endpoint_parsed = urlparse (endpoint_url )
73
- if (
74
- url_parsed .netloc != endpoint_parsed .netloc
75
- or url_parsed .scheme
76
- != endpoint_parsed .scheme
77
- ):
78
- error_msg = (
79
- "Endpoint origin does not match "
80
- f"connection origin: { endpoint_url } "
52
+ with catch ({
53
+ Exception : handle_exception ,
54
+ }):
55
+ async with anyio .create_task_group () as tg :
56
+ try :
57
+ logger .info (f"Connecting to SSE endpoint: { remove_request_params (url )} " )
58
+ async with httpx .AsyncClient (headers = headers ) as client :
59
+ async with aconnect_sse (
60
+ client ,
61
+ "GET" ,
62
+ url ,
63
+ timeout = httpx .Timeout (timeout , read = sse_read_timeout ),
64
+ ) as event_source :
65
+ event_source .response .raise_for_status ()
66
+ logger .debug ("SSE connection established" )
67
+
68
+ async def sse_reader (
69
+ task_status : TaskStatus [str ] = anyio .TASK_STATUS_IGNORED ,
70
+ ):
71
+ try :
72
+ async for sse in event_source .aiter_sse ():
73
+ logger .debug (f"Received SSE event: { sse .event } " )
74
+ match sse .event :
75
+ case "endpoint" :
76
+ endpoint_url = urljoin (url , sse .data )
77
+ logger .info (
78
+ f"Received endpoint URL: { endpoint_url } "
81
79
)
82
- logger .error (error_msg )
83
- raise ValueError (error_msg )
84
-
85
- task_status .started (endpoint_url )
86
80
87
- case "message" :
88
- try :
89
- message = types .JSONRPCMessage .model_validate_json ( # noqa: E501
90
- sse .data
91
- )
92
- logger .debug (
93
- f"Received server message: { message } "
81
+ url_parsed = urlparse (url )
82
+ endpoint_parsed = urlparse (endpoint_url )
83
+ if (
84
+ url_parsed .netloc
85
+ != endpoint_parsed .netloc
86
+ or url_parsed .scheme
87
+ != endpoint_parsed .scheme
88
+ ):
89
+ error_msg = (
90
+ "Endpoint origin does not match "
91
+ f"connection origin: { endpoint_url } "
92
+ )
93
+ logger .error (error_msg )
94
+ raise ValueError (error_msg )
95
+
96
+ task_status .started (endpoint_url )
97
+
98
+ case "message" :
99
+ try :
100
+ message = types .JSONRPCMessage .model_validate_json ( # noqa: E501
101
+ sse .data
102
+ )
103
+ logger .debug (
104
+ f"Received server message: "
105
+ f"{ message } "
106
+ )
107
+ except Exception as exc :
108
+ logger .error (
109
+ f"Error parsing server message: "
110
+ f"{ exc } "
111
+ )
112
+ await read_stream_writer .send (exc )
113
+ continue
114
+
115
+ await read_stream_writer .send (message )
116
+ case _:
117
+ logger .warning (
118
+ f"Unknown SSE event: { sse .event } "
94
119
)
95
- except Exception as exc :
96
- logger .error (
97
- f"Error parsing server message: { exc } "
98
- )
99
- await read_stream_writer .send (exc )
100
- continue
101
-
102
- await read_stream_writer .send (message )
103
- case _:
104
- logger .warning (
105
- f"Unknown SSE event: { sse .event } "
120
+ except Exception as exc :
121
+ logger .error (f"Error in sse_reader: { exc } " )
122
+ raise
123
+ finally :
124
+ await read_stream_writer .aclose ()
125
+
126
+ async def post_writer (endpoint_url : str ):
127
+ try :
128
+ async with write_stream_reader :
129
+ async for message in write_stream_reader :
130
+ logger .debug (
131
+ f"Sending client message: { message } "
106
132
)
107
- except Exception as exc :
108
- logger .error (f"Error in sse_reader: { exc } " )
109
- raise
110
- finally :
111
- await read_stream_writer .aclose ()
133
+ response = await client .post (
134
+ endpoint_url ,
135
+ json = message .model_dump (
136
+ by_alias = True ,
137
+ mode = "json" ,
138
+ exclude_none = True ,
139
+ ),
140
+ )
141
+ response .raise_for_status ()
142
+ logger .debug (
143
+ "Client message sent successfully: "
144
+ f"{ response .status_code } "
145
+ )
146
+ except Exception as exc :
147
+ logger .error (f"Error in post_writer: { exc } " )
148
+ finally :
149
+ await write_stream .aclose ()
150
+
151
+ endpoint_url = await tg .start (sse_reader )
152
+ logger .info (
153
+ f"Starting post writer with endpoint URL: { endpoint_url } "
154
+ )
155
+ tg .start_soon (post_writer , endpoint_url )
112
156
113
- async def post_writer (endpoint_url : str ):
114
157
try :
115
- async with write_stream_reader :
116
- async for message in write_stream_reader :
117
- logger .debug (f"Sending client message: { message } " )
118
- response = await client .post (
119
- endpoint_url ,
120
- json = message .model_dump (
121
- by_alias = True ,
122
- mode = "json" ,
123
- exclude_none = True ,
124
- ),
125
- )
126
- response .raise_for_status ()
127
- logger .debug (
128
- "Client message sent successfully: "
129
- f"{ response .status_code } "
130
- )
131
- except Exception as exc :
132
- logger .error (f"Error in post_writer: { exc } " )
158
+ yield read_stream , write_stream
133
159
finally :
134
- await write_stream .aclose ()
135
-
136
- endpoint_url = await tg .start (sse_reader )
137
- logger .info (
138
- f"Starting post writer with endpoint URL: { endpoint_url } "
139
- )
140
- tg .start_soon (post_writer , endpoint_url )
141
-
142
- try :
143
- yield read_stream , write_stream
144
- finally :
145
- tg .cancel_scope .cancel ()
146
- except* ValueError as eg :
147
- errors .extend (eg .exceptions )
148
- except* Exception as eg :
149
- errors .extend (eg .exceptions )
150
- finally :
151
- await read_stream_writer .aclose ()
152
- await write_stream .aclose ()
153
- if errors :
154
- raise Exception ("TaskGroup failed with: " + " " .join ([str (e ) for e in errors ]))
160
+ tg .cancel_scope .cancel ()
161
+ finally :
162
+ await read_stream_writer .aclose ()
163
+ await write_stream .aclose ()
0 commit comments