Skip to content

Commit c2dc97f

Browse files
committed
fix: `except*` syntax incompatibility with Python 3.10
* Added the `exceptiongroup` dependency in pyproject.toml so exception groups are handled uniformly from Python 3.10 through 3.13.
* Added a `handle_exception` function to handle exceptions raised by the task group.
* Updated uv.lock.
1 parent c202846 commit c2dc97f

File tree

3 files changed

+118
-106
lines changed

3 files changed

+118
-106
lines changed

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ dependencies = [
3030
"sse-starlette>=1.6.1",
3131
"pydantic-settings>=2.5.2",
3232
"uvicorn>=0.23.1",
33+
"exceptiongroup>=1.2.0",
3334
]
3435

3536
[project.optional-dependencies]

src/mcp/client/sse.py

Lines changed: 114 additions & 105 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import httpx
88
from anyio.abc import TaskStatus
99
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
10+
from exceptiongroup import ExceptionGroup, catch
1011
from httpx_sse import aconnect_sse
1112

1213
import mcp.types as types
@@ -18,6 +19,14 @@ def remove_request_params(url: str) -> str:
1819
return urljoin(url, urlparse(url).path)
1920

2021

22+
def handle_exception(exc: Exception) -> str:
23+
"""Handle ExceptionGroup and Exceptions for Client transport for SSE"""
24+
if isinstance(exc, ExceptionGroup):
25+
messages = "; ".join(str(e) for e in exc.exceptions)
26+
raise Exception(f"TaskGroup failed with: {messages}") from None
27+
else:
28+
raise Exception(f"TaskGroup failed with: {exc}") from None
29+
2130
@asynccontextmanager
2231
async def sse_client(
2332
url: str,
@@ -40,115 +49,115 @@ async def sse_client(
4049
read_stream_writer, read_stream = anyio.create_memory_object_stream(0)
4150
write_stream, write_stream_reader = anyio.create_memory_object_stream(0)
4251

43-
errors: list[Exception] = []
44-
45-
async with anyio.create_task_group() as tg:
46-
try:
47-
logger.info(f"Connecting to SSE endpoint: {remove_request_params(url)}")
48-
async with httpx.AsyncClient(headers=headers) as client:
49-
async with aconnect_sse(
50-
client,
51-
"GET",
52-
url,
53-
timeout=httpx.Timeout(timeout, read=sse_read_timeout),
54-
) as event_source:
55-
event_source.response.raise_for_status()
56-
logger.debug("SSE connection established")
57-
58-
async def sse_reader(
59-
task_status: TaskStatus[str] = anyio.TASK_STATUS_IGNORED,
60-
):
61-
try:
62-
async for sse in event_source.aiter_sse():
63-
logger.debug(f"Received SSE event: {sse.event}")
64-
match sse.event:
65-
case "endpoint":
66-
endpoint_url = urljoin(url, sse.data)
67-
logger.info(
68-
f"Received endpoint URL: {endpoint_url}"
69-
)
70-
71-
url_parsed = urlparse(url)
72-
endpoint_parsed = urlparse(endpoint_url)
73-
if (
74-
url_parsed.netloc != endpoint_parsed.netloc
75-
or url_parsed.scheme
76-
!= endpoint_parsed.scheme
77-
):
78-
error_msg = (
79-
"Endpoint origin does not match "
80-
f"connection origin: {endpoint_url}"
52+
with catch({
53+
Exception: handle_exception,
54+
}):
55+
async with anyio.create_task_group() as tg:
56+
try:
57+
logger.info(f"Connecting to SSE endpoint: {remove_request_params(url)}")
58+
async with httpx.AsyncClient(headers=headers) as client:
59+
async with aconnect_sse(
60+
client,
61+
"GET",
62+
url,
63+
timeout=httpx.Timeout(timeout, read=sse_read_timeout),
64+
) as event_source:
65+
event_source.response.raise_for_status()
66+
logger.debug("SSE connection established")
67+
68+
async def sse_reader(
69+
task_status: TaskStatus[str] = anyio.TASK_STATUS_IGNORED,
70+
):
71+
try:
72+
async for sse in event_source.aiter_sse():
73+
logger.debug(f"Received SSE event: {sse.event}")
74+
match sse.event:
75+
case "endpoint":
76+
endpoint_url = urljoin(url, sse.data)
77+
logger.info(
78+
f"Received endpoint URL: {endpoint_url}"
8179
)
82-
logger.error(error_msg)
83-
raise ValueError(error_msg)
84-
85-
task_status.started(endpoint_url)
8680

87-
case "message":
88-
try:
89-
message = types.JSONRPCMessage.model_validate_json( # noqa: E501
90-
sse.data
91-
)
92-
logger.debug(
93-
f"Received server message: {message}"
81+
url_parsed = urlparse(url)
82+
endpoint_parsed = urlparse(endpoint_url)
83+
if (
84+
url_parsed.netloc
85+
!= endpoint_parsed.netloc
86+
or url_parsed.scheme
87+
!= endpoint_parsed.scheme
88+
):
89+
error_msg = (
90+
"Endpoint origin does not match "
91+
f"connection origin: {endpoint_url}"
92+
)
93+
logger.error(error_msg)
94+
raise ValueError(error_msg)
95+
96+
task_status.started(endpoint_url)
97+
98+
case "message":
99+
try:
100+
message = types.JSONRPCMessage.model_validate_json( # noqa: E501
101+
sse.data
102+
)
103+
logger.debug(
104+
f"Received server message: "
105+
f"{message}"
106+
)
107+
except Exception as exc:
108+
logger.error(
109+
f"Error parsing server message: "
110+
f"{exc}"
111+
)
112+
await read_stream_writer.send(exc)
113+
continue
114+
115+
await read_stream_writer.send(message)
116+
case _:
117+
logger.warning(
118+
f"Unknown SSE event: {sse.event}"
94119
)
95-
except Exception as exc:
96-
logger.error(
97-
f"Error parsing server message: {exc}"
98-
)
99-
await read_stream_writer.send(exc)
100-
continue
101-
102-
await read_stream_writer.send(message)
103-
case _:
104-
logger.warning(
105-
f"Unknown SSE event: {sse.event}"
120+
except Exception as exc:
121+
logger.error(f"Error in sse_reader: {exc}")
122+
raise
123+
finally:
124+
await read_stream_writer.aclose()
125+
126+
async def post_writer(endpoint_url: str):
127+
try:
128+
async with write_stream_reader:
129+
async for message in write_stream_reader:
130+
logger.debug(
131+
f"Sending client message: {message}"
106132
)
107-
except Exception as exc:
108-
logger.error(f"Error in sse_reader: {exc}")
109-
raise
110-
finally:
111-
await read_stream_writer.aclose()
133+
response = await client.post(
134+
endpoint_url,
135+
json=message.model_dump(
136+
by_alias=True,
137+
mode="json",
138+
exclude_none=True,
139+
),
140+
)
141+
response.raise_for_status()
142+
logger.debug(
143+
"Client message sent successfully: "
144+
f"{response.status_code}"
145+
)
146+
except Exception as exc:
147+
logger.error(f"Error in post_writer: {exc}")
148+
finally:
149+
await write_stream.aclose()
150+
151+
endpoint_url = await tg.start(sse_reader)
152+
logger.info(
153+
f"Starting post writer with endpoint URL: {endpoint_url}"
154+
)
155+
tg.start_soon(post_writer, endpoint_url)
112156

113-
async def post_writer(endpoint_url: str):
114157
try:
115-
async with write_stream_reader:
116-
async for message in write_stream_reader:
117-
logger.debug(f"Sending client message: {message}")
118-
response = await client.post(
119-
endpoint_url,
120-
json=message.model_dump(
121-
by_alias=True,
122-
mode="json",
123-
exclude_none=True,
124-
),
125-
)
126-
response.raise_for_status()
127-
logger.debug(
128-
"Client message sent successfully: "
129-
f"{response.status_code}"
130-
)
131-
except Exception as exc:
132-
logger.error(f"Error in post_writer: {exc}")
158+
yield read_stream, write_stream
133159
finally:
134-
await write_stream.aclose()
135-
136-
endpoint_url = await tg.start(sse_reader)
137-
logger.info(
138-
f"Starting post writer with endpoint URL: {endpoint_url}"
139-
)
140-
tg.start_soon(post_writer, endpoint_url)
141-
142-
try:
143-
yield read_stream, write_stream
144-
finally:
145-
tg.cancel_scope.cancel()
146-
except* ValueError as eg:
147-
errors.extend(eg.exceptions)
148-
except* Exception as eg:
149-
errors.extend(eg.exceptions)
150-
finally:
151-
await read_stream_writer.aclose()
152-
await write_stream.aclose()
153-
if errors:
154-
raise Exception("TaskGroup failed with: " + " ".join([str(e) for e in errors]))
160+
tg.cancel_scope.cancel()
161+
finally:
162+
await read_stream_writer.aclose()
163+
await write_stream.aclose()

uv.lock

Lines changed: 3 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)