Skip to content

Commit 1cb7407

Browse files
authored
Properly infer prefix for SSE messages (#659)
1 parent 5289d06 commit 1cb7407

File tree

2 files changed

+87
-3
lines changed

2 files changed

+87
-3
lines changed

src/mcp/server/sse.py

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -100,19 +100,37 @@ async def connect_sse(self, scope: Scope, receive: Receive, send: Send):
100100
write_stream, write_stream_reader = anyio.create_memory_object_stream(0)
101101

102102
session_id = uuid4()
103-
session_uri = f"{quote(self._endpoint)}?session_id={session_id.hex}"
104103
self._read_stream_writers[session_id] = read_stream_writer
105104
logger.debug(f"Created new session with ID: {session_id}")
106105

106+
# Determine the full path for the message endpoint to be sent to the client.
107+
# scope['root_path'] is the prefix where the current Starlette app
108+
# instance is mounted.
109+
# e.g., "" if top-level, or "/api_prefix" if mounted under "/api_prefix".
110+
root_path = scope.get("root_path", "")
111+
112+
# self._endpoint is the path *within* this app, e.g., "/messages".
113+
# Concatenating them gives the full absolute path from the server root.
114+
# e.g., "" + "/messages" -> "/messages"
115+
# e.g., "/api_prefix" + "/messages" -> "/api_prefix/messages"
116+
full_message_path_for_client = root_path.rstrip("/") + self._endpoint
117+
118+
# This is the URI (path + query) the client will use to POST messages.
119+
client_post_uri_data = (
120+
f"{quote(full_message_path_for_client)}?session_id={session_id.hex}"
121+
)
122+
107123
sse_stream_writer, sse_stream_reader = anyio.create_memory_object_stream[
108124
dict[str, Any]
109125
](0)
110126

111127
async def sse_writer():
112128
logger.debug("Starting SSE writer")
113129
async with sse_stream_writer, write_stream_reader:
114-
await sse_stream_writer.send({"event": "endpoint", "data": session_uri})
115-
logger.debug(f"Sent endpoint event: {session_uri}")
130+
await sse_stream_writer.send(
131+
{"event": "endpoint", "data": client_post_uri_data}
132+
)
133+
logger.debug(f"Sent endpoint event: {client_post_uri_data}")
116134

117135
async for session_message in write_stream_reader:
118136
logger.debug(f"Sending message via SSE: {session_message}")

tests/shared/test_sse.py

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -252,3 +252,69 @@ async def test_sse_client_timeout(
252252
return
253253

254254
pytest.fail("the client should have timed out and returned an error already")
255+
256+
257+
def run_mounted_server(server_port: int) -> None:
258+
app = make_server_app()
259+
main_app = Starlette(routes=[Mount("/mounted_app", app=app)])
260+
server = uvicorn.Server(
261+
config=uvicorn.Config(
262+
app=main_app, host="127.0.0.1", port=server_port, log_level="error"
263+
)
264+
)
265+
print(f"starting server on {server_port}")
266+
server.run()
267+
268+
# Give server time to start
269+
while not server.started:
270+
print("waiting for server to start")
271+
time.sleep(0.5)
272+
273+
274+
@pytest.fixture()
275+
def mounted_server(server_port: int) -> Generator[None, None, None]:
276+
proc = multiprocessing.Process(
277+
target=run_mounted_server, kwargs={"server_port": server_port}, daemon=True
278+
)
279+
print("starting process")
280+
proc.start()
281+
282+
# Wait for server to be running
283+
max_attempts = 20
284+
attempt = 0
285+
print("waiting for server to start")
286+
while attempt < max_attempts:
287+
try:
288+
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
289+
s.connect(("127.0.0.1", server_port))
290+
break
291+
except ConnectionRefusedError:
292+
time.sleep(0.1)
293+
attempt += 1
294+
else:
295+
raise RuntimeError(f"Server failed to start after {max_attempts} attempts")
296+
297+
yield
298+
299+
print("killing server")
300+
# Signal the server to stop
301+
proc.kill()
302+
proc.join(timeout=2)
303+
if proc.is_alive():
304+
print("server process failed to terminate")
305+
306+
307+
@pytest.mark.anyio
308+
async def test_sse_client_basic_connection_mounted_app(
309+
mounted_server: None, server_url: str
310+
) -> None:
311+
async with sse_client(server_url + "/mounted_app/sse") as streams:
312+
async with ClientSession(*streams) as session:
313+
# Test initialization
314+
result = await session.initialize()
315+
assert isinstance(result, InitializeResult)
316+
assert result.serverInfo.name == SERVER_NAME
317+
318+
# Test ping
319+
ping_result = await session.send_ping()
320+
assert isinstance(ping_result, EmptyResult)

0 commit comments

Comments
 (0)