|
18 | 18 |
|
19 | 19 | from mcp.client.auth import OAuthClientProvider, TokenStorage
|
20 | 20 | from mcp.client.session import ClientSession
|
| 21 | +from mcp.client.sse import sse_client |
21 | 22 | from mcp.client.streamable_http import streamablehttp_client
|
22 | 23 | from mcp.shared.auth import OAuthClientInformationFull, OAuthClientMetadata, OAuthToken
|
23 | 24 |
|
@@ -149,8 +150,9 @@ def get_state(self):
|
149 | 150 | class SimpleAuthClient:
|
150 | 151 | """Simple MCP client with auth support."""
|
151 | 152 |
|
152 |
| - def __init__(self, server_url: str): |
| 153 | + def __init__(self, server_url: str, transport_type: str = "streamable_http"): |
153 | 154 | self.server_url = server_url
|
| 155 | + self.transport_type = transport_type |
154 | 156 | self.session: ClientSession | None = None
|
155 | 157 |
|
156 | 158 | async def connect(self):
|
@@ -195,38 +197,48 @@ async def _default_redirect_handler(authorization_url: str) -> None:
|
195 | 197 | callback_handler=callback_handler,
|
196 | 198 | )
|
197 | 199 |
|
198 |
| - # Create streamable HTTP transport with auth handler |
199 |
| - stream_context = streamablehttp_client( |
200 |
| - url=self.server_url, |
201 |
| - auth=oauth_auth, |
202 |
| - timeout=timedelta(seconds=60), |
203 |
| - ) |
204 |
| - |
205 |
| - print( |
206 |
| - "📡 Opening transport connection (HTTPX handles auth automatically)..." |
207 |
| - ) |
208 |
| - async with stream_context as (read_stream, write_stream, get_session_id): |
209 |
| - print("🤝 Initializing MCP session...") |
210 |
| - async with ClientSession(read_stream, write_stream) as session: |
211 |
| - self.session = session |
212 |
| - print("⚡ Starting session initialization...") |
213 |
| - await session.initialize() |
214 |
| - print("✨ Session initialization complete!") |
215 |
| - |
216 |
| - print(f"\n✅ Connected to MCP server at {self.server_url}") |
217 |
| - session_id = get_session_id() |
218 |
| - if session_id: |
219 |
| - print(f"Session ID: {session_id}") |
220 |
| - |
221 |
| - # Run interactive loop |
222 |
| - await self.interactive_loop() |
| 200 | + # Create transport with auth handler based on transport type |
| 201 | + if self.transport_type == "sse": |
| 202 | + print("📡 Opening SSE transport connection with auth...") |
| 203 | + async with sse_client( |
| 204 | + url=self.server_url, |
| 205 | + auth=oauth_auth, |
| 206 | + timeout=60, |
| 207 | + ) as (read_stream, write_stream): |
| 208 | + await self._run_session(read_stream, write_stream, None) |
| 209 | + else: |
| 210 | + print("📡 Opening StreamableHTTP transport connection with auth...") |
| 211 | + async with streamablehttp_client( |
| 212 | + url=self.server_url, |
| 213 | + auth=oauth_auth, |
| 214 | + timeout=timedelta(seconds=60), |
| 215 | + ) as (read_stream, write_stream, get_session_id): |
| 216 | + await self._run_session(read_stream, write_stream, get_session_id) |
223 | 217 |
|
224 | 218 | except Exception as e:
|
225 | 219 | print(f"❌ Failed to connect: {e}")
|
226 | 220 | import traceback
|
227 | 221 |
|
228 | 222 | traceback.print_exc()
|
229 | 223 |
|
| 224 | + async def _run_session(self, read_stream, write_stream, get_session_id): |
| 225 | + """Run the MCP session with the given streams.""" |
| 226 | + print("🤝 Initializing MCP session...") |
| 227 | + async with ClientSession(read_stream, write_stream) as session: |
| 228 | + self.session = session |
| 229 | + print("⚡ Starting session initialization...") |
| 230 | + await session.initialize() |
| 231 | + print("✨ Session initialization complete!") |
| 232 | + |
| 233 | + print(f"\n✅ Connected to MCP server at {self.server_url}") |
| 234 | + if get_session_id: |
| 235 | + session_id = get_session_id() |
| 236 | + if session_id: |
| 237 | + print(f"Session ID: {session_id}") |
| 238 | + |
| 239 | + # Run interactive loop |
| 240 | + await self.interactive_loop() |
| 241 | + |
230 | 242 | async def list_tools(self):
|
231 | 243 | """List available tools from the server."""
|
232 | 244 | if not self.session:
|
@@ -326,13 +338,20 @@ async def main():
|
326 | 338 | """Main entry point."""
|
327 | 339 | # Default server URL - can be overridden with environment variable
|
328 | 340 | # Most MCP streamable HTTP servers use /mcp as the endpoint
|
329 |
| - server_url = os.getenv("MCP_SERVER_URL", "http://localhost:8000/mcp") |
| 341 | + server_url = os.getenv("MCP_SERVER_PORT", 8000) |
| 342 | + transport_type = os.getenv("MCP_TRANSPORT_TYPE", "streamable_http") |
| 343 | + server_url = ( |
| 344 | + f"http://localhost:{server_url}/mcp" |
| 345 | + if transport_type == "streamable_http" |
| 346 | + else f"http://localhost:{server_url}/sse" |
| 347 | + ) |
330 | 348 |
|
331 | 349 | print("🚀 Simple MCP Auth Client")
|
332 | 350 | print(f"Connecting to: {server_url}")
|
| 351 | + print(f"Transport type: {transport_type}") |
333 | 352 |
|
334 | 353 | # Start connection flow - OAuth will be handled automatically
|
335 |
| - client = SimpleAuthClient(server_url) |
| 354 | + client = SimpleAuthClient(server_url, transport_type) |
336 | 355 | await client.connect()
|
337 | 356 |
|
338 | 357 |
|
|
0 commit comments