diff --git a/client/src/App.tsx b/client/src/App.tsx index 32bdcf35..c32f4840 100644 --- a/client/src/App.tsx +++ b/client/src/App.tsx @@ -17,6 +17,9 @@ import { Tool, LoggingLevel, } from "@modelcontextprotocol/sdk/types.js"; +import { OAuthTokensSchema } from "@modelcontextprotocol/sdk/shared/auth.js"; +import { SESSION_KEYS, getServerSpecificKey } from "./lib/constants"; +import { AuthDebuggerState } from "./lib/auth-types"; import React, { Suspense, useCallback, @@ -28,18 +31,21 @@ import { useConnection } from "./lib/hooks/useConnection"; import { useDraggablePane } from "./lib/hooks/useDraggablePane"; import { StdErrNotification } from "./lib/notificationTypes"; -import { Tabs, TabsList, TabsTrigger } from "@/components/ui/tabs"; +import { Tabs, TabsContent, TabsList, TabsTrigger } from "@/components/ui/tabs"; +import { Button } from "@/components/ui/button"; import { Bell, Files, FolderTree, Hammer, Hash, + Key, MessageSquare, } from "lucide-react"; import { z } from "zod"; import "./App.css"; +import AuthDebugger from "./components/AuthDebugger"; import ConsoleTab from "./components/ConsoleTab"; import HistoryAndNotifications from "./components/History"; import PingTab from "./components/PingTab"; @@ -111,6 +117,27 @@ const App = () => { } > >([]); + const [isAuthDebuggerVisible, setIsAuthDebuggerVisible] = useState(false); + + // Auth debugger state + const [authState, setAuthState] = useState({ + isInitiatingAuth: false, + oauthTokens: null, + loading: true, + oauthStep: "metadata_discovery", + oauthMetadata: null, + oauthClientInfo: null, + authorizationUrl: null, + authorizationCode: "", + latestError: null, + statusMessage: null, + validationError: null, + }); + + // Helper function to update specific auth state properties + const updateAuthState = (updates: Partial) => { + setAuthState((prev) => ({ ...prev, ...updates })); + }; const nextRequestId = useRef(0); const rootsRef = useRef([]); @@ -208,11 +235,64 @@ const App = () => { (serverUrl: string) => { setSseUrl(serverUrl); setTransportType("sse"); + setIsAuthDebuggerVisible(false); void connectMcpServer(); }, [connectMcpServer], ); + // Update OAuth debug state during debug callback + const onOAuthDebugConnect = useCallback( + ({ + authorizationCode, + errorMsg, + }: { + authorizationCode?: string; + errorMsg?: string; + }) => { + setIsAuthDebuggerVisible(true); + if (authorizationCode) { + updateAuthState({ + authorizationCode, + oauthStep: "token_request", + }); + } + if (errorMsg) { + updateAuthState({ + latestError: new Error(errorMsg), + }); + } + }, + [], + ); + + // Load OAuth tokens when sseUrl changes + useEffect(() => { + const loadOAuthTokens = async () => { + try { + if (sseUrl) { + const key = getServerSpecificKey(SESSION_KEYS.TOKENS, sseUrl); + const tokens = sessionStorage.getItem(key); + if (tokens) { + const parsedTokens = await OAuthTokensSchema.parseAsync( + JSON.parse(tokens), + ); + updateAuthState({ + oauthTokens: parsedTokens, + oauthStep: "complete", + }); + } + } + } catch (error) { + console.error("Error loading OAuth tokens:", error); + } finally { + updateAuthState({ loading: false }); + } + }; + + loadOAuthTokens(); + }, [sseUrl]); + useEffect(() => { fetch(`${getMCPProxyAddress(config)}/config`) .then((response) => response.json()) @@ -446,6 +526,19 @@ const App = () => { setStdErrNotifications([]); }; + // Helper component for rendering the AuthDebugger + const AuthDebuggerWrapper = () => ( + + setIsAuthDebuggerVisible(false)} + authState={authState} + updateAuthState={updateAuthState} + /> + + ); + + // Helper function to render OAuth callback components if (window.location.pathname === "/oauth/callback") { const OAuthCallback = React.lazy( () => import("./components/OAuthCallback"), @@ -457,6 +550,17 @@ const App = () => { ); } + if (window.location.pathname === "/oauth/callback/debug") { + const OAuthDebugCallback = React.lazy( + () => import("./components/OAuthDebugCallback"), + ); + return ( + Loading...}> + + + ); + } + return (
{ Roots + + + Auth +
@@ -689,15 +797,36 @@ const App = () => { setRoots={setRoots} onRootsChange={handleRootsChange} /> + )}
+ ) : isAuthDebuggerVisible ? ( + (window.location.hash = value)} + > + + ) : ( -
+

Connect to an MCP server to start inspecting

+
+

+ Need to configure authentication? +

+ +
)}
diff --git a/client/src/components/AuthDebugger.tsx b/client/src/components/AuthDebugger.tsx new file mode 100644 index 00000000..fa863373 --- /dev/null +++ b/client/src/components/AuthDebugger.tsx @@ -0,0 +1,260 @@ +import { useCallback, useMemo } from "react"; +import { Button } from "@/components/ui/button"; +import { DebugInspectorOAuthClientProvider } from "../lib/auth"; +import { auth } from "@modelcontextprotocol/sdk/client/auth.js"; +import { AlertCircle } from "lucide-react"; +import { AuthDebuggerState } from "../lib/auth-types"; +import { OAuthFlowProgress } from "./OAuthFlowProgress"; +import { OAuthStateMachine } from "../lib/oauth-state-machine"; + +export interface AuthDebuggerProps { + serverUrl: string; + onBack: () => void; + authState: AuthDebuggerState; + updateAuthState: (updates: Partial) => void; +} + +interface StatusMessageProps { + message: { type: "error" | "success" | "info"; message: string }; +} + +const StatusMessage = ({ message }: StatusMessageProps) => { + let bgColor: string; + let textColor: string; + let borderColor: string; + + switch (message.type) { + case "error": + bgColor = "bg-red-50"; + textColor = "text-red-700"; + borderColor = "border-red-200"; + break; + case "success": + bgColor = "bg-green-50"; + textColor = "text-green-700"; + borderColor = "border-green-200"; + break; + case "info": + default: + bgColor = "bg-blue-50"; + textColor = "text-blue-700"; + borderColor = "border-blue-200"; + break; + } + + return ( +
+
+ +

{message.message}

+
+
+ ); +}; + +const AuthDebugger = ({ + serverUrl: serverUrl, + onBack, + authState, + updateAuthState, +}: AuthDebuggerProps) => { + const startOAuthFlow = useCallback(() => { + if (!serverUrl) { + updateAuthState({ + statusMessage: { + type: "error", + message: + "Please enter a server URL in the sidebar before authenticating", + }, + }); + return; + } + + updateAuthState({ + oauthStep: "metadata_discovery", + authorizationUrl: null, + statusMessage: null, + latestError: null, + }); + }, [serverUrl, updateAuthState]); + + const stateMachine = useMemo( + () => new OAuthStateMachine(serverUrl, updateAuthState), + [serverUrl, updateAuthState], + ); + + const proceedToNextStep = useCallback(async () => { + if (!serverUrl) return; + + try { + updateAuthState({ + isInitiatingAuth: true, + statusMessage: null, + latestError: null, + }); + + await stateMachine.executeStep(authState); + } catch (error) { + console.error("OAuth flow error:", error); + updateAuthState({ + latestError: error instanceof Error ? error : new Error(String(error)), + }); + } finally { + updateAuthState({ isInitiatingAuth: false }); + } + }, [serverUrl, authState, updateAuthState, stateMachine]); + + const handleQuickOAuth = useCallback(async () => { + if (!serverUrl) { + updateAuthState({ + statusMessage: { + type: "error", + message: + "Please enter a server URL in the sidebar before authenticating", + }, + }); + return; + } + + updateAuthState({ isInitiatingAuth: true, statusMessage: null }); + try { + const serverAuthProvider = new DebugInspectorOAuthClientProvider( + serverUrl, + ); + await auth(serverAuthProvider, { serverUrl: serverUrl }); + updateAuthState({ + statusMessage: { + type: "info", + message: "Starting OAuth authentication process...", + }, + }); + } catch (error) { + console.error("OAuth initialization error:", error); + updateAuthState({ + statusMessage: { + type: "error", + message: `Failed to start OAuth flow: ${error instanceof Error ? error.message : String(error)}`, + }, + }); + } finally { + updateAuthState({ isInitiatingAuth: false }); + } + }, [serverUrl, updateAuthState]); + + const handleClearOAuth = useCallback(() => { + if (serverUrl) { + const serverAuthProvider = new DebugInspectorOAuthClientProvider( + serverUrl, + ); + serverAuthProvider.clear(); + updateAuthState({ + oauthTokens: null, + oauthStep: "metadata_discovery", + latestError: null, + oauthClientInfo: null, + authorizationCode: "", + validationError: null, + oauthMetadata: null, + statusMessage: { + type: "success", + message: "OAuth tokens cleared successfully", + }, + }); + + // Clear success message after 3 seconds + setTimeout(() => { + updateAuthState({ statusMessage: null }); + }, 3000); + } + }, [serverUrl, updateAuthState]); + + return ( +
+
+

Authentication Settings

+ +
+ +
+
+
+

+ Configure authentication settings for your MCP server connection. +

+ +
+

OAuth Authentication

+

+ Use OAuth to securely authenticate with the MCP server. +

+ + {authState.statusMessage && ( + + )} + + {authState.loading ? ( +

Loading authentication status...

+ ) : ( +
+ {authState.oauthTokens && ( +
+

Access Token:

+
+ {authState.oauthTokens.access_token.substring(0, 25)}... +
+
+ )} + +
+ + + + + +
+ +

+ Choose "Guided" for step-by-step instructions or "Quick" for + the standard automatic flow. +

+
+ )} +
+ + +
+
+
+
+ ); +}; + +export default AuthDebugger; diff --git a/client/src/components/OAuthDebugCallback.tsx b/client/src/components/OAuthDebugCallback.tsx new file mode 100644 index 00000000..88d931c0 --- /dev/null +++ b/client/src/components/OAuthDebugCallback.tsx @@ -0,0 +1,92 @@ +import { useEffect } from "react"; +import { SESSION_KEYS } from "../lib/constants"; +import { + generateOAuthErrorDescription, + parseOAuthCallbackParams, +} from "@/utils/oauthUtils.ts"; + +interface OAuthCallbackProps { + onConnect: ({ + authorizationCode, + errorMsg, + }: { + authorizationCode?: string; + errorMsg?: string; + }) => void; +} + +const OAuthDebugCallback = ({ onConnect }: OAuthCallbackProps) => { + useEffect(() => { + let isProcessed = false; + + const handleCallback = async () => { + // Skip if we've already processed this callback + if (isProcessed) { + return; + } + isProcessed = true; + + const params = parseOAuthCallbackParams(window.location.search); + if (!params.successful) { + const errorMsg = generateOAuthErrorDescription(params); + onConnect({ errorMsg }); + return; + } + + const serverUrl = sessionStorage.getItem(SESSION_KEYS.SERVER_URL); + + // ServerURL isn't set, this can happen if we've opened the + // authentication request in a new tab, so we don't have the same + // session storage + if (!serverUrl) { + // If there's no server URL, we're likely in a new tab + // Just display the code for manual copying + return; + } + + if (!params.code) { + onConnect({ errorMsg: "Missing authorization code" }); + return; + } + + // Instead of storing in sessionStorage, pass the code directly + // to the auth state manager through onConnect + onConnect({ authorizationCode: params.code }); + }; + + handleCallback().finally(() => { + // Only redirect if we have the URL set, otherwise assume this was + // in a new tab + if (sessionStorage.getItem(SESSION_KEYS.SERVER_URL)) { + window.history.replaceState({}, document.title, "/"); + } + }); + + return () => { + isProcessed = true; + }; + }, [onConnect]); + + const callbackParams = parseOAuthCallbackParams(window.location.search); + + return ( +
+
+

+ Please copy this authorization code and return to the Auth Debugger: +

+ + {callbackParams.successful && "code" in callbackParams + ? callbackParams.code + : `No code found: ${callbackParams.error}, ${callbackParams.error_description}`} + +

+ Close this tab and paste the code in the OAuth flow to complete + authentication. +

+
+
+ ); +}; + +export default OAuthDebugCallback; diff --git a/client/src/components/OAuthFlowProgress.tsx b/client/src/components/OAuthFlowProgress.tsx new file mode 100644 index 00000000..f604fc73 --- /dev/null +++ b/client/src/components/OAuthFlowProgress.tsx @@ -0,0 +1,259 @@ +import { AuthDebuggerState, OAuthStep } from "@/lib/auth-types"; +import { CheckCircle2, Circle, ExternalLink } from "lucide-react"; +import { Button } from "./ui/button"; +import { DebugInspectorOAuthClientProvider } from "@/lib/auth"; + +interface OAuthStepProps { + label: string; + isComplete: boolean; + isCurrent: boolean; + error?: Error | null; + children?: React.ReactNode; +} + +const OAuthStepDetails = ({ + label, + isComplete, + isCurrent, + error, + children, +}: OAuthStepProps) => { + return ( +
+
+ {isComplete ? ( + + ) : ( + + )} + {label} +
+ + {/* Show children if current step or complete and children exist */} + {(isCurrent || isComplete) && children && ( +
{children}
+ )} + + {/* Display error if current step and an error exists */} + {isCurrent && error && ( +
+

Error:

+

{error.message}

+
+ )} +
+ ); +}; + +interface OAuthFlowProgressProps { + serverUrl: string; + authState: AuthDebuggerState; + updateAuthState: (updates: Partial) => void; + proceedToNextStep: () => Promise; +} + +export const OAuthFlowProgress = ({ + serverUrl, + authState, + updateAuthState, + proceedToNextStep, +}: OAuthFlowProgressProps) => { + const provider = new DebugInspectorOAuthClientProvider(serverUrl); + + const steps: Array = [ + "metadata_discovery", + "client_registration", + "authorization_redirect", + "authorization_code", + "token_request", + "complete", + ]; + const currentStepIdx = steps.findIndex((s) => s === authState.oauthStep); + + // Helper to get step props + const getStepProps = (stepName: OAuthStep) => ({ + isComplete: + currentStepIdx > steps.indexOf(stepName) || + currentStepIdx === steps.length - 1, // last step is "complete" + isCurrent: authState.oauthStep === stepName, + error: authState.oauthStep === stepName ? authState.latestError : null, + }); + + return ( +
+

OAuth Flow Progress

+

+ Follow these steps to complete OAuth authentication with the server. +

+ +
+ + {provider.getServerMetadata() && ( +
+ + Retrieved OAuth Metadata from {serverUrl} + /.well-known/oauth-authorization-server + +
+                {JSON.stringify(provider.getServerMetadata(), null, 2)}
+              
+
+ )} +
+ + + {authState.oauthClientInfo && ( +
+ + Registered Client Information + +
+                {JSON.stringify(authState.oauthClientInfo, null, 2)}
+              
+
+ )} +
+ + + {authState.authorizationUrl && ( +
+

Authorization URL:

+
+

+ {authState.authorizationUrl} +

+ + + +
+

+ Click the link to authorize in your browser. After + authorization, you'll be redirected back to continue the flow. +

+
+ )} +
+ + +
+ +
+ { + updateAuthState({ + authorizationCode: e.target.value, + validationError: null, + }); + }} + placeholder="Enter the code from the authorization server" + className={`flex h-9 w-full rounded-md border bg-background px-3 py-2 text-sm ring-offset-background file:border-0 file:bg-transparent file:text-sm file:font-medium placeholder:text-muted-foreground focus-visible:outline-none focus-visible:ring-2 focus-visible:ring-ring focus-visible:ring-offset-2 ${ + authState.validationError ? "border-red-500" : "border-input" + }`} + /> +
+ {authState.validationError && ( +

+ {authState.validationError} +

+ )} +

+ Once you've completed authorization in the link, paste the code + here. +

+
+
+ + + {authState.oauthMetadata && ( +
+ + Token Request Details + +
+

Token Endpoint:

+ + {authState.oauthMetadata.token_endpoint} + +
+
+ )} +
+ + + {authState.oauthTokens && ( +
+ + Access Tokens + +

+ Authentication successful! You can now use the authenticated + connection. These tokens will be used automatically for server + requests. +

+
+                {JSON.stringify(authState.oauthTokens, null, 2)}
+              
+
+ )} +
+
+ +
+ {authState.oauthStep !== "complete" && ( + <> + + + )} + + {authState.oauthStep === "authorization_redirect" && + authState.authorizationUrl && ( + + )} +
+
+ ); +}; diff --git a/client/src/components/__tests__/AuthDebugger.test.tsx b/client/src/components/__tests__/AuthDebugger.test.tsx new file mode 100644 index 00000000..469c1ba1 --- /dev/null +++ b/client/src/components/__tests__/AuthDebugger.test.tsx @@ -0,0 +1,382 @@ +import { + render, + screen, + fireEvent, + waitFor, + act, +} from "@testing-library/react"; +import "@testing-library/jest-dom"; +import { describe, it, beforeEach, jest } from "@jest/globals"; +import AuthDebugger, { AuthDebuggerProps } from "../AuthDebugger"; +import { TooltipProvider } from "@/components/ui/tooltip"; +import { SESSION_KEYS } from "@/lib/constants"; + +const mockOAuthTokens = { + access_token: "test_access_token", + token_type: "Bearer", + expires_in: 3600, + refresh_token: "test_refresh_token", + scope: "test_scope", +}; + +const mockOAuthMetadata = { + issuer: "https://oauth.example.com", + authorization_endpoint: "https://oauth.example.com/authorize", + token_endpoint: "https://oauth.example.com/token", + response_types_supported: ["code"], + grant_types_supported: ["authorization_code"], +}; + +const mockOAuthClientInfo = { + client_id: "test_client_id", + client_secret: "test_client_secret", + redirect_uris: ["http://localhost:3000/oauth/callback/debug"], +}; + +// Mock MCP SDK functions - must be before imports +jest.mock("@modelcontextprotocol/sdk/client/auth.js", () => ({ + auth: jest.fn(), + discoverOAuthMetadata: jest.fn(), + registerClient: jest.fn(), + startAuthorization: jest.fn(), + exchangeAuthorization: jest.fn(), +})); + +// Import the functions to get their types +import { + discoverOAuthMetadata, + registerClient, + startAuthorization, + exchangeAuthorization, +} from "@modelcontextprotocol/sdk/client/auth.js"; +import { OAuthMetadata } from "@modelcontextprotocol/sdk/shared/auth.js"; + +// Type the mocked functions properly +const mockDiscoverOAuthMetadata = discoverOAuthMetadata as jest.MockedFunction< + typeof discoverOAuthMetadata +>; +const mockRegisterClient = registerClient as jest.MockedFunction< + typeof registerClient +>; +const mockStartAuthorization = startAuthorization as jest.MockedFunction< + typeof startAuthorization +>; +const mockExchangeAuthorization = exchangeAuthorization as jest.MockedFunction< + typeof exchangeAuthorization +>; + +const sessionStorageMock = { + getItem: jest.fn(), + setItem: jest.fn(), + removeItem: jest.fn(), + clear: jest.fn(), +}; +Object.defineProperty(window, "sessionStorage", { + value: sessionStorageMock, +}); + +Object.defineProperty(window, "location", { + value: { + origin: "http://localhost:3000", + }, +}); + +describe("AuthDebugger", () => { + const defaultAuthState = { + isInitiatingAuth: false, + oauthTokens: null, + loading: false, + oauthStep: "metadata_discovery" as const, + oauthMetadata: null, + oauthClientInfo: null, + authorizationUrl: null, + authorizationCode: "", + latestError: null, + statusMessage: null, + validationError: null, + }; + + const defaultProps = { + serverUrl: "https://example.com", + onBack: jest.fn(), + authState: defaultAuthState, + updateAuthState: jest.fn(), + }; + + beforeEach(() => { + jest.clearAllMocks(); + sessionStorageMock.getItem.mockReturnValue(null); + + mockDiscoverOAuthMetadata.mockResolvedValue(mockOAuthMetadata); + mockRegisterClient.mockResolvedValue(mockOAuthClientInfo); + mockStartAuthorization.mockImplementation(async (_sseUrl, options) => { + const authUrl = new URL("https://oauth.example.com/authorize"); + + if (options.scope) { + authUrl.searchParams.set("scope", options.scope); + } + + return { + authorizationUrl: authUrl, + codeVerifier: "test_verifier", + }; + }); + mockExchangeAuthorization.mockResolvedValue(mockOAuthTokens); + }); + + const renderAuthDebugger = (props: Partial = {}) => { + const mergedProps = { + ...defaultProps, + ...props, + authState: { ...defaultAuthState, ...(props.authState || {}) }, + }; + return render( + + + , + ); + }; + + describe("Initial Rendering", () => { + it("should render the component with correct title", async () => { + await act(async () => { + renderAuthDebugger(); + }); + expect(screen.getByText("Authentication Settings")).toBeInTheDocument(); + }); + + it("should call onBack when Back button is clicked", async () => { + const onBack = jest.fn(); + await act(async () => { + renderAuthDebugger({ onBack }); + }); + fireEvent.click(screen.getByText("Back to Connect")); + expect(onBack).toHaveBeenCalled(); + }); + }); + + describe("OAuth Flow", () => { + it("should start OAuth flow when 'Guided OAuth Flow' is clicked", async () => { + await act(async () => { + renderAuthDebugger(); + }); + + await act(async () => { + fireEvent.click(screen.getByText("Guided OAuth Flow")); + }); + + expect(screen.getByText("OAuth Flow Progress")).toBeInTheDocument(); + }); + + it("should show error when OAuth flow is started without sseUrl", async () => { + const updateAuthState = jest.fn(); + await act(async () => { + renderAuthDebugger({ serverUrl: "", updateAuthState }); + }); + + await act(async () => { + fireEvent.click(screen.getByText("Guided OAuth Flow")); + }); + + expect(updateAuthState).toHaveBeenCalledWith({ + statusMessage: { + type: "error", + message: + "Please enter a server URL in the sidebar before authenticating", + }, + }); + }); + }); + + describe("Session Storage Integration", () => { + it("should load OAuth tokens from session storage", async () => { + // Mock the specific key for tokens with server URL + sessionStorageMock.getItem.mockImplementation((key) => { + if (key === "[https://example.com] mcp_tokens") { + return JSON.stringify(mockOAuthTokens); + } + return null; + }); + + await act(async () => { + renderAuthDebugger({ + authState: { + ...defaultAuthState, + oauthTokens: mockOAuthTokens, + }, + }); + }); + + await waitFor(() => { + expect(screen.getByText(/Access Token:/)).toBeInTheDocument(); + }); + }); + + it("should handle errors loading OAuth tokens from session storage", async () => { + // Mock console to avoid cluttering test output + const originalError = console.error; + console.error = jest.fn(); + + // Mock getItem to return invalid JSON for tokens + sessionStorageMock.getItem.mockImplementation((key) => { + if (key === "[https://example.com] mcp_tokens") { + return "invalid json"; + } + return null; + }); + + await act(async () => { + renderAuthDebugger(); + }); + + // Component should still render despite the error + expect(screen.getByText("Authentication Settings")).toBeInTheDocument(); + + // Restore console.error + console.error = originalError; + }); + }); + + describe("OAuth State Management", () => { + it("should clear OAuth state when Clear button is clicked", async () => { + const updateAuthState = jest.fn(); + // Mock the session storage to return tokens for the specific key + sessionStorageMock.getItem.mockImplementation((key) => { + if (key === "[https://example.com] mcp_tokens") { + return JSON.stringify(mockOAuthTokens); + } + return null; + }); + + await act(async () => { + renderAuthDebugger({ + authState: { + ...defaultAuthState, + oauthTokens: mockOAuthTokens, + }, + updateAuthState, + }); + }); + + await act(async () => { + fireEvent.click(screen.getByText("Clear OAuth State")); + }); + + expect(updateAuthState).toHaveBeenCalledWith({ + oauthTokens: null, + oauthStep: "metadata_discovery", + latestError: null, + oauthClientInfo: null, + oauthMetadata: null, + authorizationCode: "", + validationError: null, + statusMessage: { + type: "success", + message: "OAuth tokens cleared successfully", + }, + }); + + // Verify session storage was cleared + expect(sessionStorageMock.removeItem).toHaveBeenCalled(); + }); + }); + + describe("OAuth Flow Steps", () => { + it("should handle OAuth flow step progression", async () => { + const updateAuthState = jest.fn(); + await act(async () => { + renderAuthDebugger({ + updateAuthState, + authState: { + ...defaultAuthState, + isInitiatingAuth: false, // Changed to false so button is enabled + oauthStep: "metadata_discovery", + }, + }); + }); + + // Verify metadata discovery step + expect(screen.getByText("Metadata Discovery")).toBeInTheDocument(); + + // Click Continue - this should trigger metadata discovery + await act(async () => { + fireEvent.click(screen.getByText("Continue")); + }); + + expect(mockDiscoverOAuthMetadata).toHaveBeenCalledWith( + "https://example.com", + ); + }); + + // Setup helper for OAuth authorization tests + const setupAuthorizationUrlTest = async (metadata: OAuthMetadata) => { + const updateAuthState = jest.fn(); + + // Mock the session storage to return metadata + sessionStorageMock.getItem.mockImplementation((key) => { + if (key === `[https://example.com] ${SESSION_KEYS.SERVER_METADATA}`) { + return JSON.stringify(metadata); + } + if ( + key === `[https://example.com] ${SESSION_KEYS.CLIENT_INFORMATION}` + ) { + return JSON.stringify(mockOAuthClientInfo); + } + return null; + }); + + await act(async () => { + renderAuthDebugger({ + updateAuthState, + authState: { + ...defaultAuthState, + isInitiatingAuth: false, + oauthStep: "authorization_redirect", + oauthMetadata: metadata, + oauthClientInfo: mockOAuthClientInfo, + }, + }); + }); + + // Click Continue to trigger authorization + await act(async () => { + fireEvent.click(screen.getByText("Continue")); + }); + + return updateAuthState; + }; + + it("should include scope in authorization URL when scopes_supported is present", async () => { + const metadataWithScopes = { + ...mockOAuthMetadata, + scopes_supported: ["read", "write", "admin"], + }; + + const updateAuthState = + await setupAuthorizationUrlTest(metadataWithScopes); + + // Wait for the updateAuthState to be called + await waitFor(() => { + expect(updateAuthState).toHaveBeenCalledWith( + expect.objectContaining({ + authorizationUrl: expect.stringContaining("scope="), + }), + ); + }); + }); + + it("should not include scope in authorization URL when scopes_supported is not present", async () => { + const updateAuthState = + await setupAuthorizationUrlTest(mockOAuthMetadata); + + // Wait for the updateAuthState to be called + await waitFor(() => { + expect(updateAuthState).toHaveBeenCalledWith( + expect.objectContaining({ + authorizationUrl: expect.not.stringContaining("scope="), + }), + ); + }); + }); + }); +}); diff --git a/client/src/lib/auth-types.ts b/client/src/lib/auth-types.ts new file mode 100644 index 00000000..ef32601a --- /dev/null +++ b/client/src/lib/auth-types.ts @@ -0,0 +1,38 @@ +import { + OAuthMetadata, + OAuthClientInformationFull, + OAuthClientInformation, + OAuthTokens, +} from "@modelcontextprotocol/sdk/shared/auth.js"; + +// OAuth flow steps +export type OAuthStep = + | "metadata_discovery" + | "client_registration" + | "authorization_redirect" + | "authorization_code" + | "token_request" + | "complete"; + +// Message types for inline feedback +export type MessageType = "success" | "error" | "info"; + +export interface StatusMessage { + type: MessageType; + message: string; +} + +// Single state interface for OAuth state +export interface AuthDebuggerState { + isInitiatingAuth: boolean; + oauthTokens: OAuthTokens | null; + loading: boolean; + oauthStep: OAuthStep; + oauthMetadata: OAuthMetadata | null; + oauthClientInfo: OAuthClientInformationFull | OAuthClientInformation | null; + authorizationUrl: string | null; + authorizationCode: string; + latestError: Error | null; + statusMessage: StatusMessage | null; + validationError: string | null; +} diff --git a/client/src/lib/auth.ts b/client/src/lib/auth.ts index 7ef31822..3e3516e0 100644 --- a/client/src/lib/auth.ts +++ b/client/src/lib/auth.ts @@ -4,11 +4,13 @@ import { OAuthClientInformation, OAuthTokens, OAuthTokensSchema, + OAuthClientMetadata, + OAuthMetadata, } from "@modelcontextprotocol/sdk/shared/auth.js"; import { SESSION_KEYS, getServerSpecificKey } from "./constants"; export class InspectorOAuthClientProvider implements OAuthClientProvider { - constructor(private serverUrl: string) { + constructor(public serverUrl: string) { // Save the server URL to session storage sessionStorage.setItem(SESSION_KEYS.SERVER_URL, serverUrl); } @@ -17,7 +19,7 @@ export class InspectorOAuthClientProvider implements OAuthClientProvider { return window.location.origin + "/oauth/callback"; } - get clientMetadata() { + get clientMetadata(): OAuthClientMetadata { return { redirect_uris: [this.redirectUrl], token_endpoint_auth_method: "none", @@ -101,3 +103,38 @@ export class InspectorOAuthClientProvider implements OAuthClientProvider { ); } } + +// Overrides debug URL and allows saving server OAuth metadata to +// display in debug UI. +export class DebugInspectorOAuthClientProvider extends InspectorOAuthClientProvider { + get redirectUrl(): string { + return `${window.location.origin}/oauth/callback/debug`; + } + + saveServerMetadata(metadata: OAuthMetadata) { + const key = getServerSpecificKey( + SESSION_KEYS.SERVER_METADATA, + this.serverUrl, + ); + sessionStorage.setItem(key, JSON.stringify(metadata)); + } + + getServerMetadata(): OAuthMetadata | null { + const key = getServerSpecificKey( + SESSION_KEYS.SERVER_METADATA, + this.serverUrl, + ); + const metadata = sessionStorage.getItem(key); + if (!metadata) { + return null; + } + return JSON.parse(metadata); + } + + clear() { + super.clear(); + sessionStorage.removeItem( + getServerSpecificKey(SESSION_KEYS.SERVER_METADATA, this.serverUrl), + ); + } +} diff --git a/client/src/lib/constants.ts b/client/src/lib/constants.ts index a03239ae..4c3e27aa 100644 --- a/client/src/lib/constants.ts +++ b/client/src/lib/constants.ts @@ -6,6 +6,7 @@ export const SESSION_KEYS = { SERVER_URL: "mcp_server_url", TOKENS: "mcp_tokens", CLIENT_INFORMATION: "mcp_client_information", + SERVER_METADATA: "mcp_server_metadata", } as const; // Generate server-specific session storage keys diff --git a/client/src/lib/oauth-state-machine.ts b/client/src/lib/oauth-state-machine.ts new file mode 100644 index 00000000..0678229c --- /dev/null +++ b/client/src/lib/oauth-state-machine.ts @@ -0,0 +1,181 @@ +import { OAuthStep, AuthDebuggerState } from "./auth-types"; +import { DebugInspectorOAuthClientProvider } from "./auth"; +import { + discoverOAuthMetadata, + registerClient, + startAuthorization, + exchangeAuthorization, +} from "@modelcontextprotocol/sdk/client/auth.js"; +import { OAuthMetadataSchema } from "@modelcontextprotocol/sdk/shared/auth.js"; + +export interface StateMachineContext { + state: AuthDebuggerState; + serverUrl: string; + provider: DebugInspectorOAuthClientProvider; + updateState: (updates: Partial) => void; +} + +export interface StateTransition { + canTransition: (context: StateMachineContext) => Promise; + execute: (context: StateMachineContext) => Promise; + nextStep: OAuthStep; +} + +// State machine transitions +export const oauthTransitions: Record = { + metadata_discovery: { + canTransition: async () => true, + execute: async (context) => { + const metadata = await discoverOAuthMetadata(context.serverUrl); + if (!metadata) { + throw new Error("Failed to discover OAuth metadata"); + } + const parsedMetadata = await OAuthMetadataSchema.parseAsync(metadata); + context.provider.saveServerMetadata(parsedMetadata); + context.updateState({ + oauthMetadata: parsedMetadata, + oauthStep: "client_registration", + }); + }, + nextStep: "client_registration", + }, + + client_registration: { + canTransition: async (context) => !!context.state.oauthMetadata, + execute: async (context) => { + const metadata = context.state.oauthMetadata!; + const clientMetadata = context.provider.clientMetadata; + + // Add all supported scopes to client registration + if (metadata.scopes_supported) { + clientMetadata.scope = metadata.scopes_supported.join(" "); + } + + const fullInformation = await registerClient(context.serverUrl, { + metadata, + clientMetadata, + }); + + context.provider.saveClientInformation(fullInformation); + context.updateState({ + oauthClientInfo: fullInformation, + oauthStep: "authorization_redirect", + }); + }, + nextStep: "authorization_redirect", + }, + + authorization_redirect: { + canTransition: async (context) => + !!context.state.oauthMetadata && !!context.state.oauthClientInfo, + execute: async (context) => { + const metadata = context.state.oauthMetadata!; + const clientInformation = context.state.oauthClientInfo!; + + let scope: string | undefined = undefined; + if (metadata.scopes_supported) { + scope = metadata.scopes_supported.join(" "); + } + + const { authorizationUrl, codeVerifier } = await startAuthorization( + context.serverUrl, + { + metadata, + clientInformation, + redirectUrl: context.provider.redirectUrl, + scope, + }, + ); + + context.provider.saveCodeVerifier(codeVerifier); + context.updateState({ + authorizationUrl: authorizationUrl.toString(), + oauthStep: "authorization_code", + }); + }, + nextStep: "authorization_code", + }, + + authorization_code: { + canTransition: async () => true, + execute: async (context) => { + if ( + !context.state.authorizationCode || + context.state.authorizationCode.trim() === "" + ) { + context.updateState({ + validationError: "You need to provide an authorization code", + }); + // Don't advance if no code + throw new Error("Authorization code required"); + } + context.updateState({ + validationError: null, + oauthStep: "token_request", + }); + }, + nextStep: "token_request", + }, + + token_request: { + canTransition: async (context) => { + return ( + !!context.state.authorizationCode && + !!context.provider.getServerMetadata() && + !!(await context.provider.clientInformation()) + ); + }, + execute: async (context) => { + const codeVerifier = context.provider.codeVerifier(); + const metadata = context.provider.getServerMetadata()!; + const clientInformation = (await context.provider.clientInformation())!; + + const tokens = await exchangeAuthorization(context.serverUrl, { + metadata, + clientInformation, + authorizationCode: context.state.authorizationCode, + codeVerifier, + redirectUri: context.provider.redirectUrl, + }); + + context.provider.saveTokens(tokens); + context.updateState({ + oauthTokens: tokens, + oauthStep: "complete", + }); + }, + nextStep: "complete", + }, + + complete: { + canTransition: async () => false, + execute: async () => { + // No-op for complete state + }, + nextStep: "complete", + }, +}; + +export class OAuthStateMachine { + constructor( + private serverUrl: string, + private updateState: (updates: Partial) => void, + ) {} + + async executeStep(state: AuthDebuggerState): Promise { + const provider = new DebugInspectorOAuthClientProvider(this.serverUrl); + const context: StateMachineContext = { + state, + serverUrl: this.serverUrl, + provider, + updateState: this.updateState, + }; + + const transition = oauthTransitions[state.oauthStep]; + if (!(await transition.canTransition(context))) { + throw new Error(`Cannot transition from ${state.oauthStep}`); + } + + await transition.execute(context); + } +} diff --git a/package-lock.json b/package-lock.json index e0e23c9f..92e89f74 100644 --- a/package-lock.json +++ b/package-lock.json @@ -17,7 +17,7 @@ "@modelcontextprotocol/inspector-cli": "^0.12.0", "@modelcontextprotocol/inspector-client": "^0.12.0", "@modelcontextprotocol/inspector-server": "^0.12.0", - "@modelcontextprotocol/sdk": "^1.11.0", + "@modelcontextprotocol/sdk": "^1.11.2", "concurrently": "^9.0.1", "open": "^10.1.0", "shell-quote": "^1.8.2", @@ -43,7 +43,7 @@ "version": "0.12.0", "license": "MIT", "dependencies": { - "@modelcontextprotocol/sdk": "^1.10.2", + "@modelcontextprotocol/sdk": "^1.11.0", "commander": "^13.1.0", "spawn-rx": "^5.1.2" }, @@ -66,7 +66,7 @@ "version": "0.12.0", "license": "MIT", "dependencies": { - "@modelcontextprotocol/sdk": "^1.10.2", + "@modelcontextprotocol/sdk": "^1.11.0", "@radix-ui/react-checkbox": "^1.1.4", "@radix-ui/react-dialog": "^1.1.3", "@radix-ui/react-icons": "^1.3.0", @@ -2004,10 +2004,9 @@ "link": true }, "node_modules/@modelcontextprotocol/sdk": { - "version": "1.11.0", - "resolved": "https://registry.npmjs.org/@modelcontextprotocol/sdk/-/sdk-1.11.0.tgz", - "integrity": "sha512-k/1pb70eD638anoi0e8wUGAlbMJXyvdV4p62Ko+EZ7eBe1xMx8Uhak1R5DgfoofsK5IBBnRwsYGTaLZl+6/+RQ==", - "license": "MIT", + "version": "1.11.2", + "resolved": "https://registry.npmjs.org/@modelcontextprotocol/sdk/-/sdk-1.11.2.tgz", + "integrity": "sha512-H9vwztj5OAqHg9GockCQC06k1natgcxWQSRpQcPJf6i5+MWBzfKkRtxGbjQf0X2ihii0ffLZCRGbYV2f2bjNCQ==", "dependencies": { "content-type": "^1.0.5", "cors": "^2.8.5", @@ -10937,7 +10936,7 @@ "version": "0.12.0", "license": "MIT", "dependencies": { - "@modelcontextprotocol/sdk": "^1.10.2", + "@modelcontextprotocol/sdk": "^1.11.0", "cors": "^2.8.5", "express": "^5.1.0", "ws": "^8.18.0", diff --git a/package.json b/package.json index 01beeac4..dcbd3bf1 100644 --- a/package.json +++ b/package.json @@ -43,7 +43,7 @@ "@modelcontextprotocol/inspector-cli": "^0.12.0", "@modelcontextprotocol/inspector-client": "^0.12.0", "@modelcontextprotocol/inspector-server": "^0.12.0", - "@modelcontextprotocol/sdk": "^1.11.0", + "@modelcontextprotocol/sdk": "^1.11.2", "concurrently": "^9.0.1", "open": "^10.1.0", "shell-quote": "^1.8.2",