|
1 |
| -import os |
2 | 1 | import sys
|
3 | 2 | from contextlib import asynccontextmanager
|
4 |
| -from pathlib import Path |
5 |
| -from typing import Literal, TextIO |
| 3 | +from typing import TextIO |
6 | 4 |
|
7 |
| -import anyio |
8 |
| -import anyio.lowlevel |
9 |
| -from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream |
10 |
| -from anyio.streams.text import TextReceiveStream |
11 |
| -from pydantic import BaseModel, Field |
12 |
| - |
13 |
| -import mcp.types as types |
14 |
| -from mcp.shared.message import SessionMessage |
15 |
| - |
16 |
| -from .win32 import ( |
17 |
| - create_windows_process, |
18 |
| - get_windows_executable_command, |
19 |
| - terminate_windows_process, |
20 |
| -) |
21 |
| - |
22 |
| -# Environment variables to inherit by default |
23 |
| -DEFAULT_INHERITED_ENV_VARS = ( |
24 |
| - [ |
25 |
| - "APPDATA", |
26 |
| - "HOMEDRIVE", |
27 |
| - "HOMEPATH", |
28 |
| - "LOCALAPPDATA", |
29 |
| - "PATH", |
30 |
| - "PROCESSOR_ARCHITECTURE", |
31 |
| - "SYSTEMDRIVE", |
32 |
| - "SYSTEMROOT", |
33 |
| - "TEMP", |
34 |
| - "USERNAME", |
35 |
| - "USERPROFILE", |
36 |
| - ] |
37 |
| - if sys.platform == "win32" |
38 |
| - else ["HOME", "LOGNAME", "PATH", "SHELL", "TERM", "USER"] |
39 |
| -) |
40 |
| - |
41 |
| - |
42 |
| -def get_default_environment() -> dict[str, str]: |
43 |
| - """ |
44 |
| - Returns a default environment object including only environment variables deemed |
45 |
| - safe to inherit. |
46 |
| - """ |
47 |
| - env: dict[str, str] = {} |
48 |
| - |
49 |
| - for key in DEFAULT_INHERITED_ENV_VARS: |
50 |
| - value = os.environ.get(key) |
51 |
| - if value is None: |
52 |
| - continue |
53 |
| - |
54 |
| - if value.startswith("()"): |
55 |
| - # Skip functions, which are a security risk |
56 |
| - continue |
57 |
| - |
58 |
| - env[key] = value |
59 |
| - |
60 |
| - return env |
61 |
| - |
62 |
| - |
63 |
| -class StdioServerParameters(BaseModel): |
64 |
| - command: str |
65 |
| - """The executable to run to start the server.""" |
66 |
| - |
67 |
| - args: list[str] = Field(default_factory=list) |
68 |
| - """Command line arguments to pass to the executable.""" |
69 |
| - |
70 |
| - env: dict[str, str] | None = None |
71 |
| - """ |
72 |
| - The environment to use when spawning the process. |
73 |
| -
|
74 |
| - If not specified, the result of get_default_environment() will be used. |
75 |
| - """ |
76 |
| - |
77 |
| - cwd: str | Path | None = None |
78 |
| - """The working directory to use when spawning the process.""" |
79 |
| - |
80 |
| - encoding: str = "utf-8" |
81 |
| - """ |
82 |
| - The text encoding used when sending/receiving messages to the server |
83 |
| -
|
84 |
| - defaults to utf-8 |
85 |
| - """ |
86 |
| - |
87 |
| - encoding_error_handler: Literal["strict", "ignore", "replace"] = "strict" |
88 |
| - """ |
89 |
| - The text encoding error handler. |
90 |
| -
|
91 |
| - See https://docs.python.org/3/library/codecs.html#codec-base-classes for |
92 |
| - explanations of possible values |
93 |
| - """ |
| 5 | +# Import from the new files |
| 6 | +from .parameters import StdioServerParameters |
| 7 | +from .transport import StdioClientTransport |
94 | 8 |
|
95 | 9 |
|
96 | 10 | @asynccontextmanager
|
97 | 11 | async def stdio_client(server: StdioServerParameters, errlog: TextIO = sys.stderr):
|
98 | 12 | """
|
99 |
| - Client transport for stdio: this will connect to a server by spawning a |
100 |
| - process and communicating with it over stdin/stdout. |
| 13 | + Client transport for stdio: connects to a server by spawning a process |
| 14 | + and communicating with it over stdin/stdout, managed by StdioClientTransport. |
101 | 15 | """
|
102 |
| - read_stream: MemoryObjectReceiveStream[SessionMessage | Exception] |
103 |
| - read_stream_writer: MemoryObjectSendStream[SessionMessage | Exception] |
104 |
| - |
105 |
| - write_stream: MemoryObjectSendStream[SessionMessage] |
106 |
| - write_stream_reader: MemoryObjectReceiveStream[SessionMessage] |
107 |
| - |
108 |
| - read_stream_writer, read_stream = anyio.create_memory_object_stream(0) |
109 |
| - write_stream, write_stream_reader = anyio.create_memory_object_stream(0) |
110 |
| - |
111 |
| - command = _get_executable_command(server.command) |
112 |
| - |
113 |
| - # Open process with stderr piped for capture |
114 |
| - process = await _create_platform_compatible_process( |
115 |
| - command=command, |
116 |
| - args=server.args, |
117 |
| - env=( |
118 |
| - {**get_default_environment(), **server.env} |
119 |
| - if server.env is not None |
120 |
| - else get_default_environment() |
121 |
| - ), |
122 |
| - errlog=errlog, |
123 |
| - cwd=server.cwd, |
124 |
| - ) |
125 |
| - |
126 |
| - async def stdout_reader(): |
127 |
| - assert process.stdout, "Opened process is missing stdout" |
128 |
| - |
129 |
| - try: |
130 |
| - async with read_stream_writer: |
131 |
| - buffer = "" |
132 |
| - async for chunk in TextReceiveStream( |
133 |
| - process.stdout, |
134 |
| - encoding=server.encoding, |
135 |
| - errors=server.encoding_error_handler, |
136 |
| - ): |
137 |
| - lines = (buffer + chunk).split("\n") |
138 |
| - buffer = lines.pop() |
139 |
| - |
140 |
| - for line in lines: |
141 |
| - try: |
142 |
| - message = types.JSONRPCMessage.model_validate_json(line) |
143 |
| - except Exception as exc: |
144 |
| - await read_stream_writer.send(exc) |
145 |
| - continue |
146 |
| - |
147 |
| - session_message = SessionMessage(message) |
148 |
| - await read_stream_writer.send(session_message) |
149 |
| - except anyio.ClosedResourceError: |
150 |
| - await anyio.lowlevel.checkpoint() |
151 |
| - |
152 |
| - async def stdin_writer(): |
153 |
| - assert process.stdin, "Opened process is missing stdin" |
| 16 | + transport = StdioClientTransport(server_params=server, errlog=errlog) |
| 17 | + async with transport as streams: |
| 18 | + yield streams |
154 | 19 |
|
155 |
| - try: |
156 |
| - async with write_stream_reader: |
157 |
| - async for session_message in write_stream_reader: |
158 |
| - json = session_message.message.model_dump_json( |
159 |
| - by_alias=True, exclude_none=True |
160 |
| - ) |
161 |
| - await process.stdin.send( |
162 |
| - (json + "\n").encode( |
163 |
| - encoding=server.encoding, |
164 |
| - errors=server.encoding_error_handler, |
165 |
| - ) |
166 |
| - ) |
167 |
| - except anyio.ClosedResourceError: |
168 |
| - await anyio.lowlevel.checkpoint() |
169 |
| - |
170 |
| - async with ( |
171 |
| - anyio.create_task_group() as tg, |
172 |
| - process, |
173 |
| - ): |
174 |
| - tg.start_soon(stdout_reader) |
175 |
| - tg.start_soon(stdin_writer) |
176 |
| - try: |
177 |
| - yield read_stream, write_stream |
178 |
| - finally: |
179 |
| - # Clean up process to prevent any dangling orphaned processes |
180 |
| - if sys.platform == "win32": |
181 |
| - await terminate_windows_process(process) |
182 |
| - else: |
183 |
| - process.terminate() |
184 |
| - await read_stream.aclose() |
185 |
| - await write_stream.aclose() |
186 |
| - |
187 |
| - |
188 |
| -def _get_executable_command(command: str) -> str: |
189 |
| - """ |
190 |
| - Get the correct executable command normalized for the current platform. |
191 |
| -
|
192 |
| - Args: |
193 |
| - command: Base command (e.g., 'uvx', 'npx') |
194 |
| -
|
195 |
| - Returns: |
196 |
| - str: Platform-appropriate command |
197 |
| - """ |
198 |
| - if sys.platform == "win32": |
199 |
| - return get_windows_executable_command(command) |
200 |
| - else: |
201 |
| - return command |
202 |
| - |
203 |
| - |
204 |
| -async def _create_platform_compatible_process( |
205 |
| - command: str, |
206 |
| - args: list[str], |
207 |
| - env: dict[str, str] | None = None, |
208 |
| - errlog: TextIO = sys.stderr, |
209 |
| - cwd: Path | str | None = None, |
210 |
| -): |
211 |
| - """ |
212 |
| - Creates a subprocess in a platform-compatible way. |
213 |
| - Returns a process handle. |
214 |
| - """ |
215 |
| - if sys.platform == "win32": |
216 |
| - process = await create_windows_process(command, args, env, errlog, cwd) |
217 |
| - else: |
218 |
| - process = await anyio.open_process( |
219 |
| - [command, *args], env=env, stderr=errlog, cwd=cwd |
220 |
| - ) |
221 | 20 |
|
222 |
| - return process |
| 21 | +# Ensure __all__ or exports are updated if this was a public API change, though |
| 22 | +# stdio_client itself remains the primary public entry point from this file. |
0 commit comments