Skip to content

Commit 8e0e369

Browse files
committed
Refactor stdio client to modularize transport and parameters
Separated the stdio client logic into dedicated modules for transport and parameter handling. The refactoring reduces complexity by isolating process management into `StdioClientTransport` and moving environment defaulting to a `parameters` module, improving maintainability and readability. Signed-off-by: DanielAvdar <66269169+DanielAvdar@users.noreply.github.com>
1 parent a5b53cf commit 8e0e369

File tree

6 files changed

+1070
-901
lines changed

6 files changed

+1070
-901
lines changed

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ dev = [
5555
"pytest-xdist>=3.6.1",
5656
"pytest-examples>=0.0.14",
5757
"pytest-pretty>=1.2.0",
58+
"pytest-cov>=6.1.1",
5859
]
5960
docs = [
6061
"mkdocs>=1.6.1",

src/mcp/client/stdio/__init__.py

Lines changed: 11 additions & 211 deletions
Original file line numberDiff line numberDiff line change
@@ -1,222 +1,22 @@
1-
import os
21
import sys
32
from contextlib import asynccontextmanager
4-
from pathlib import Path
5-
from typing import Literal, TextIO
3+
from typing import TextIO
64

7-
import anyio
8-
import anyio.lowlevel
9-
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
10-
from anyio.streams.text import TextReceiveStream
11-
from pydantic import BaseModel, Field
12-
13-
import mcp.types as types
14-
from mcp.shared.message import SessionMessage
15-
16-
from .win32 import (
17-
create_windows_process,
18-
get_windows_executable_command,
19-
terminate_windows_process,
20-
)
21-
22-
# Environment variables to inherit by default
23-
DEFAULT_INHERITED_ENV_VARS = (
24-
[
25-
"APPDATA",
26-
"HOMEDRIVE",
27-
"HOMEPATH",
28-
"LOCALAPPDATA",
29-
"PATH",
30-
"PROCESSOR_ARCHITECTURE",
31-
"SYSTEMDRIVE",
32-
"SYSTEMROOT",
33-
"TEMP",
34-
"USERNAME",
35-
"USERPROFILE",
36-
]
37-
if sys.platform == "win32"
38-
else ["HOME", "LOGNAME", "PATH", "SHELL", "TERM", "USER"]
39-
)
40-
41-
42-
def get_default_environment() -> dict[str, str]:
43-
"""
44-
Returns a default environment object including only environment variables deemed
45-
safe to inherit.
46-
"""
47-
env: dict[str, str] = {}
48-
49-
for key in DEFAULT_INHERITED_ENV_VARS:
50-
value = os.environ.get(key)
51-
if value is None:
52-
continue
53-
54-
if value.startswith("()"):
55-
# Skip functions, which are a security risk
56-
continue
57-
58-
env[key] = value
59-
60-
return env
61-
62-
63-
class StdioServerParameters(BaseModel):
64-
command: str
65-
"""The executable to run to start the server."""
66-
67-
args: list[str] = Field(default_factory=list)
68-
"""Command line arguments to pass to the executable."""
69-
70-
env: dict[str, str] | None = None
71-
"""
72-
The environment to use when spawning the process.
73-
74-
If not specified, the result of get_default_environment() will be used.
75-
"""
76-
77-
cwd: str | Path | None = None
78-
"""The working directory to use when spawning the process."""
79-
80-
encoding: str = "utf-8"
81-
"""
82-
The text encoding used when sending/receiving messages to the server
83-
84-
defaults to utf-8
85-
"""
86-
87-
encoding_error_handler: Literal["strict", "ignore", "replace"] = "strict"
88-
"""
89-
The text encoding error handler.
90-
91-
See https://docs.python.org/3/library/codecs.html#codec-base-classes for
92-
explanations of possible values
93-
"""
5+
# Import from the new files
6+
from .parameters import StdioServerParameters
7+
from .transport import StdioClientTransport
948

959

9610
@asynccontextmanager
9711
async def stdio_client(server: StdioServerParameters, errlog: TextIO = sys.stderr):
9812
"""
99-
Client transport for stdio: this will connect to a server by spawning a
100-
process and communicating with it over stdin/stdout.
13+
Client transport for stdio: connects to a server by spawning a process
14+
and communicating with it over stdin/stdout, managed by StdioClientTransport.
10115
"""
102-
read_stream: MemoryObjectReceiveStream[SessionMessage | Exception]
103-
read_stream_writer: MemoryObjectSendStream[SessionMessage | Exception]
104-
105-
write_stream: MemoryObjectSendStream[SessionMessage]
106-
write_stream_reader: MemoryObjectReceiveStream[SessionMessage]
107-
108-
read_stream_writer, read_stream = anyio.create_memory_object_stream(0)
109-
write_stream, write_stream_reader = anyio.create_memory_object_stream(0)
110-
111-
command = _get_executable_command(server.command)
112-
113-
# Open process with stderr piped for capture
114-
process = await _create_platform_compatible_process(
115-
command=command,
116-
args=server.args,
117-
env=(
118-
{**get_default_environment(), **server.env}
119-
if server.env is not None
120-
else get_default_environment()
121-
),
122-
errlog=errlog,
123-
cwd=server.cwd,
124-
)
125-
126-
async def stdout_reader():
127-
assert process.stdout, "Opened process is missing stdout"
128-
129-
try:
130-
async with read_stream_writer:
131-
buffer = ""
132-
async for chunk in TextReceiveStream(
133-
process.stdout,
134-
encoding=server.encoding,
135-
errors=server.encoding_error_handler,
136-
):
137-
lines = (buffer + chunk).split("\n")
138-
buffer = lines.pop()
139-
140-
for line in lines:
141-
try:
142-
message = types.JSONRPCMessage.model_validate_json(line)
143-
except Exception as exc:
144-
await read_stream_writer.send(exc)
145-
continue
146-
147-
session_message = SessionMessage(message)
148-
await read_stream_writer.send(session_message)
149-
except anyio.ClosedResourceError:
150-
await anyio.lowlevel.checkpoint()
151-
152-
async def stdin_writer():
153-
assert process.stdin, "Opened process is missing stdin"
16+
transport = StdioClientTransport(server_params=server, errlog=errlog)
17+
async with transport as streams:
18+
yield streams
15419

155-
try:
156-
async with write_stream_reader:
157-
async for session_message in write_stream_reader:
158-
json = session_message.message.model_dump_json(
159-
by_alias=True, exclude_none=True
160-
)
161-
await process.stdin.send(
162-
(json + "\n").encode(
163-
encoding=server.encoding,
164-
errors=server.encoding_error_handler,
165-
)
166-
)
167-
except anyio.ClosedResourceError:
168-
await anyio.lowlevel.checkpoint()
169-
170-
async with (
171-
anyio.create_task_group() as tg,
172-
process,
173-
):
174-
tg.start_soon(stdout_reader)
175-
tg.start_soon(stdin_writer)
176-
try:
177-
yield read_stream, write_stream
178-
finally:
179-
# Clean up process to prevent any dangling orphaned processes
180-
if sys.platform == "win32":
181-
await terminate_windows_process(process)
182-
else:
183-
process.terminate()
184-
await read_stream.aclose()
185-
await write_stream.aclose()
186-
187-
188-
def _get_executable_command(command: str) -> str:
189-
"""
190-
Get the correct executable command normalized for the current platform.
191-
192-
Args:
193-
command: Base command (e.g., 'uvx', 'npx')
194-
195-
Returns:
196-
str: Platform-appropriate command
197-
"""
198-
if sys.platform == "win32":
199-
return get_windows_executable_command(command)
200-
else:
201-
return command
202-
203-
204-
async def _create_platform_compatible_process(
205-
command: str,
206-
args: list[str],
207-
env: dict[str, str] | None = None,
208-
errlog: TextIO = sys.stderr,
209-
cwd: Path | str | None = None,
210-
):
211-
"""
212-
Creates a subprocess in a platform-compatible way.
213-
Returns a process handle.
214-
"""
215-
if sys.platform == "win32":
216-
process = await create_windows_process(command, args, env, errlog, cwd)
217-
else:
218-
process = await anyio.open_process(
219-
[command, *args], env=env, stderr=errlog, cwd=cwd
220-
)
22120

222-
return process
21+
# Ensure __all__ or exports are updated if this was a public API change, though
22+
# stdio_client itself remains the primary public entry point from this file.

src/mcp/client/stdio/parameters.py

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
import os
2+
import sys
3+
from pathlib import Path
4+
from typing import Literal
5+
6+
from pydantic import BaseModel, Field
7+
8+
# Environment variables to inherit by default
9+
DEFAULT_INHERITED_ENV_VARS = (
10+
[
11+
"APPDATA",
12+
"HOMEDRIVE",
13+
"HOMEPATH",
14+
"LOCALAPPDATA",
15+
"PATH",
16+
"PROCESSOR_ARCHITECTURE",
17+
"SYSTEMDRIVE",
18+
"SYSTEMROOT",
19+
"TEMP",
20+
"USERNAME",
21+
"USERPROFILE",
22+
]
23+
if sys.platform == "win32"
24+
else ["HOME", "LOGNAME", "PATH", "SHELL", "TERM", "USER"]
25+
)
26+
27+
28+
def get_default_environment() -> dict[str, str]:
29+
"""Returns a default environment object including only environment variables deemed
30+
safe to inherit.
31+
"""
32+
env: dict[str, str] = {}
33+
34+
for key in DEFAULT_INHERITED_ENV_VARS:
35+
value = os.environ.get(key)
36+
if value is None:
37+
continue
38+
39+
if value.startswith("()"):
40+
# Skip functions, which are a security risk
41+
continue
42+
43+
env[key] = value
44+
45+
return env
46+
47+
48+
class StdioServerParameters(BaseModel):
49+
command: str
50+
"""The executable to run to start the server."""
51+
52+
args: list[str] = Field(default_factory=list)
53+
"""Command line arguments to pass to the executable."""
54+
55+
env: dict[str, str] | None = None
56+
"""
57+
The environment to use when spawning the process.
58+
59+
If not specified, the result of get_default_environment() will be used.
60+
"""
61+
62+
cwd: str | Path | None = None
63+
"""The working directory to use when spawning the process."""
64+
65+
encoding: str = "utf-8"
66+
"""
67+
The text encoding used when sending/receiving messages to the server
68+
69+
defaults to utf-8
70+
"""
71+
72+
encoding_error_handler: Literal["strict", "ignore", "replace"] = "strict"
73+
"""
74+
The text encoding error handler.
75+
76+
See https://docs.python.org/3/library/codecs.html#codec-base-classes for
77+
explanations of possible values
78+
"""

0 commit comments

Comments
 (0)