Skip to content

Commit 4bed186

Browse files
committed
test
1 parent 15dd3d0 commit 4bed186

File tree

2 files changed

+26
-9
lines changed

2 files changed

+26
-9
lines changed

src/mcp/server/sse.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -129,8 +129,20 @@ async def sse_writer():
129129
logger.debug("Starting SSE response task")
130130
tg.start_soon(response, scope, receive, send)
131131

132-
logger.debug("Yielding read and write streams")
133-
yield (read_stream, write_stream, response)
132+
try:
133+
logger.debug("Yielding read and write streams")
134+
yield (read_stream, write_stream, response)
135+
finally:
136+
# Cleanup when connection closes
137+
logger.debug(f"Cleaning up SSE session {session_id}")
138+
try:
139+
# Remove session from tracking dictionary
140+
if session_id in self._read_stream_writers:
141+
del self._read_stream_writers[session_id]
142+
# Cancel any remaining tasks in the task group
143+
tg.cancel_scope.cancel()
144+
except Exception as e:
145+
logger.error(f"Error during SSE cleanup: {e}")
134146

135147
async def handle_post_message(
136148
self, scope: Scope, receive: Receive, send: Send

tests/shared/test_sse.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -142,12 +142,19 @@ def server(server_port: int) -> Generator[None, None, None]:
142142

143143
yield
144144

145-
print("killing server")
146-
# Signal the server to stop
147-
proc.kill()
148-
proc.join(timeout=2)
145+
print("shutting down server gracefully")
146+
# Try graceful shutdown first
147+
proc.terminate()
148+
try:
149+
proc.join(timeout=5)
150+
except Exception:
151+
print("Graceful shutdown failed, forcing kill")
152+
proc.kill()
153+
proc.join(timeout=2)
154+
149155
if proc.is_alive():
150156
print("server process failed to terminate")
157+
proc.kill() # Force kill as last resort
151158

152159

153160

@@ -180,9 +187,7 @@ async def test_raw_sse_connection(server, server_url) -> None:
180187
pytest.fail(f"{e}")
181188

182189
@pytest.mark.anyio
183-
@pytest.mark.skip(
184-
"fails in CI, but works locally. Need to investigate why."
185-
)
190+
186191
async def test_sse_client_basic_connection(server: None, server_url: str) -> None:
187192
async with sse_client(server_url + "/sse") as streams:
188193
async with ClientSession(*streams) as session:

0 commit comments

Comments
 (0)