Skip to content

Commit 96baf33

Browse files
authored
Merge pull request #2 from sebin1213/main
Fix generate accurate session_uri for nested SSE paths
2 parents 8e76e8a + 7cece5e commit 96baf33

File tree

2 files changed

+9
-14
lines changed

2 files changed

+9
-14
lines changed

src/mcp/client/sse.py

Lines changed: 3 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
import logging
2-
import re
32
from contextlib import asynccontextmanager
43
from typing import Any
54
from urllib.parse import urljoin, urlparse
@@ -62,21 +61,12 @@ async def sse_reader(
6261
logger.debug(f"Received SSE event: {sse.event}")
6362
match sse.event:
6463
case "endpoint":
65-
url_parsed = urlparse(url)
66-
67-
base_path = re.search(
68-
r"https?://[^/]+/(.+?)(?:/mcp)?/sse$", url
69-
)
70-
base_path = (
71-
base_path.group(1) if base_path else ""
72-
)
73-
endpoint_url = urljoin(
74-
url_parsed.scheme + "://" + url_parsed.netloc, # noqa: E501
75-
base_path + sse.data
76-
)
64+
endpoint_url = urljoin(url, sse.data)
7765
logger.info(
7866
f"Received endpoint URL: {endpoint_url}"
7967
)
68+
69+
url_parsed = urlparse(url)
8070

8171
endpoint_parsed = urlparse(endpoint_url)
8272
if (

src/mcp/server/sse.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ async def handle_sse(request):
3737
from urllib.parse import quote
3838
from uuid import UUID, uuid4
3939

40+
import re
4041
import anyio
4142
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
4243
from pydantic import ValidationError
@@ -95,7 +96,11 @@ async def connect_sse(self, scope: Scope, receive: Receive, send: Send):
9596
write_stream, write_stream_reader = anyio.create_memory_object_stream(0)
9697

9798
session_id = uuid4()
98-
session_uri = f"{quote(self._endpoint)}?session_id={session_id.hex}"
99+
request_path = scope["path"]
100+
match = re.match(r"^/([^/]+(?:/mcp)?)/sse$", request_path)
101+
mount_prefix = match.group(1) if match else ""
102+
session_uri = f"/{quote(mount_prefix)}{quote(self._endpoint)}?session_id={session_id.hex}"
103+
99104
self._read_stream_writers[session_id] = read_stream_writer
100105
logger.debug(f"Created new session with ID: {session_id}")
101106

0 commit comments

Comments
 (0)