|
4 | 4 | Corresponds to TypeScript file: src/server/auth/handlers/authorize.ts
|
5 | 5 | """
|
6 | 6 |
|
7 |
| -from typing import Callable, Literal, Optional |
| 7 | +from typing import Callable, Literal, Optional, Union |
8 | 8 | from urllib.parse import parse_qs, urlencode, urlparse, urlunparse
|
9 | 9 |
|
10 |
| -from pydantic import AnyHttpUrl, AnyUrl, BaseModel, Field, ValidationError |
| 10 | +from pydantic import AnyHttpUrl, AnyUrl, BaseModel, Field, RootModel, ValidationError |
| 11 | +from starlette.datastructures import FormData, QueryParams |
11 | 12 | from starlette.requests import Request
|
12 | 13 | from starlette.responses import RedirectResponse, Response
|
13 | 14 |
|
14 | 15 | from mcp.server.auth.errors import (
|
15 | 16 | InvalidClientError,
|
16 | 17 | InvalidRequestError,
|
17 | 18 | OAuthError,
|
| 19 | + stringify_pydantic_error, |
18 | 20 | )
|
19 |
| -from mcp.server.auth.provider import AuthorizationParams, OAuthServerProvider |
| 21 | +from mcp.server.auth.provider import AuthorizationParams, OAuthServerProvider, construct_redirect_uri |
20 | 22 | from mcp.shared.auth import OAuthClientInformationFull
|
| 23 | +from mcp.server.auth.json_response import PydanticJSONResponse |
21 | 24 |
|
22 | 25 | import logging
|
23 | 26 |
|
24 | 27 | logger = logging.getLogger(__name__)
|
25 | 28 |
|
26 | 29 |
|
27 | 30 | class AuthorizationRequest(BaseModel):
|
| 31 | + # See https://datatracker.ietf.org/doc/html/rfc6749#section-4.1.1 |
28 | 32 | client_id: str = Field(..., description="The client ID")
|
29 | 33 | redirect_uri: AnyHttpUrl | None = Field(
|
30 |
| - ..., description="URL to redirect to after authorization" |
| 34 | + None, description="URL to redirect to after authorization" |
31 | 35 | )
|
32 | 36 |
|
33 | 37 | # see OAuthClientMetadata; we only support `code`
|
@@ -61,108 +65,200 @@ def validate_scope(
|
61 | 65 |
|
62 | 66 |
|
63 | 67 | def validate_redirect_uri(
|
64 |
| - auth_request: AuthorizationRequest, client: OAuthClientInformationFull |
| 68 | + redirect_uri: AnyHttpUrl | None, client: OAuthClientInformationFull |
65 | 69 | ) -> AnyHttpUrl:
|
66 |
| - if auth_request.redirect_uri is not None: |
| 70 | + if redirect_uri is not None: |
67 | 71 | # Validate redirect_uri against client's registered redirect URIs
|
68 |
| - if auth_request.redirect_uri not in client.redirect_uris: |
| 72 | + if redirect_uri not in client.redirect_uris: |
69 | 73 | raise InvalidRequestError(
|
70 |
| - f"Redirect URI '{auth_request.redirect_uri}' not registered for client" |
| 74 | + f"Redirect URI '{redirect_uri}' not registered for client" |
71 | 75 | )
|
72 |
| - return auth_request.redirect_uri |
| 76 | + return redirect_uri |
73 | 77 | elif len(client.redirect_uris) == 1:
|
74 | 78 | return client.redirect_uris[0]
|
75 | 79 | else:
|
76 | 80 | raise InvalidRequestError(
|
77 | 81 | "redirect_uri must be specified when client has multiple registered URIs"
|
78 | 82 | )
|
| 83 | +ErrorCode = Literal[ |
| 84 | + "invalid_request", |
| 85 | + "unauthorized_client", |
| 86 | + "access_denied", |
| 87 | + "unsupported_response_type", |
| 88 | + "invalid_scope", |
| 89 | + "server_error", |
| 90 | + "temporarily_unavailable" |
| 91 | + ] |
| 92 | +class ErrorResponse(BaseModel): |
| 93 | + error: ErrorCode |
| 94 | + error_description: str |
| 95 | + error_uri: Optional[AnyUrl] = None |
| 96 | + # must be set if provided in the request |
| 97 | + state: Optional[str] |
| 98 | + |
| 99 | +def best_effort_extract_string(key: str, params: None | FormData | QueryParams) -> Optional[str]: |
| 100 | + if params is None: |
| 101 | + return None |
| 102 | + value = params.get(key) |
| 103 | + if isinstance(value, str): |
| 104 | + return value |
| 105 | + return None |
| 106 | + |
| 107 | +class AnyHttpUrlModel(RootModel): |
| 108 | + root: AnyHttpUrl |
79 | 109 |
|
80 | 110 |
|
81 | 111 | def create_authorization_handler(provider: OAuthServerProvider) -> Callable:
|
82 |
| - """ |
83 |
| - Create a handler for the OAuth 2.0 Authorization endpoint. |
| 112 | + async def authorization_handler(request: Request) -> Response: |
| 113 | + # implements authorization requests for grant_type=code; |
| 114 | + # see https://datatracker.ietf.org/doc/html/rfc6749#section-4.1.1 |
84 | 115 |
|
85 |
| - Corresponds to authorizationHandler in src/server/auth/handlers/authorize.ts |
| 116 | + state = None |
| 117 | + redirect_uri = None |
| 118 | + client = None |
| 119 | + params = None |
86 | 120 |
|
87 |
| - """ |
| 121 | + async def error_response(error: ErrorCode, error_description: str, attempt_load_client: bool = True): |
| 122 | + nonlocal client, redirect_uri, state |
| 123 | + if client is None and attempt_load_client: |
| 124 | + # make last-ditch attempt to load the client |
| 125 | + client_id = best_effort_extract_string("client_id", params) |
| 126 | + client = client_id and await provider.clients_store.get_client(client_id) |
| 127 | + if redirect_uri is None and client: |
| 128 | + # make last-ditch effort to load the redirect uri |
| 129 | + if params is not None and "redirect_uri" not in params: |
| 130 | + raw_redirect_uri = None |
| 131 | + else: |
| 132 | + raw_redirect_uri = AnyHttpUrlModel.model_validate(best_effort_extract_string("redirect_uri", params)).root |
| 133 | + try: |
| 134 | + redirect_uri = validate_redirect_uri(raw_redirect_uri, client) |
| 135 | + except (ValidationError, InvalidRequestError): |
| 136 | + pass |
| 137 | + if state is None: |
| 138 | + # make last-ditch effort to load state |
| 139 | + state = best_effort_extract_string("state", params) |
88 | 140 |
|
89 |
| - async def authorization_handler(request: Request) -> Response: |
90 |
| - """ |
91 |
| - Handler for the OAuth 2.0 Authorization endpoint. |
92 |
| - """ |
93 |
| - # Validate request parameters |
| 141 | + error_resp = ErrorResponse( |
| 142 | + error=error, |
| 143 | + error_description=error_description, |
| 144 | + state=state, |
| 145 | + ) |
| 146 | + |
| 147 | + if redirect_uri and client: |
| 148 | + return RedirectResponse( |
| 149 | + url=construct_redirect_uri(str(redirect_uri), **error_resp.model_dump(exclude_none=True)), |
| 150 | + status_code=302, |
| 151 | + headers={"Cache-Control": "no-store"}, |
| 152 | + ) |
| 153 | + else: |
| 154 | + return PydanticJSONResponse( |
| 155 | + status_code=400, |
| 156 | + content=error_resp, |
| 157 | + headers={"Cache-Control": "no-store"}, |
| 158 | + ) |
| 159 | + |
94 | 160 | try:
|
| 161 | + # Parse request parameters |
95 | 162 | if request.method == "GET":
|
96 | 163 | # Convert query_params to dict for pydantic validation
|
97 |
| - params = dict(request.query_params) |
98 |
| - auth_request = AuthorizationRequest.model_validate(params) |
| 164 | + params = request.query_params |
99 | 165 | else:
|
100 | 166 | # Parse form data for POST requests
|
101 |
| - form_data = await request.form() |
102 |
| - params = dict(form_data) |
| 167 | + params = await request.form() |
| 168 | + |
| 169 | + # Save state if it exists, even before validation |
| 170 | + state = best_effort_extract_string("state", params) |
| 171 | + |
| 172 | + try: |
103 | 173 | auth_request = AuthorizationRequest.model_validate(params)
|
104 |
| - except ValidationError as e: |
105 |
| - raise InvalidRequestError(str(e)) |
| 174 | + state = auth_request.state # Update with validated state |
| 175 | + except ValidationError as validation_error: |
| 176 | + error: ErrorCode = "invalid_request" |
| 177 | + for e in validation_error.errors(): |
| 178 | + if e['loc'] == ('response_type',) and e['type'] == 'literal_error': |
| 179 | + error = "unsupported_response_type" |
| 180 | + break |
| 181 | + return await error_response(error, stringify_pydantic_error(validation_error)) |
106 | 182 |
|
107 |
| - # Get client information |
108 |
| - try: |
| 183 | + # Get client information |
109 | 184 | client = await provider.clients_store.get_client(auth_request.client_id)
|
110 |
| - except OAuthError as e: |
111 |
| - # TODO: proper error rendering |
112 |
| - raise InvalidClientError(str(e)) |
113 |
| - |
114 |
| - if not client: |
115 |
| - raise InvalidClientError(f"Client ID '{auth_request.client_id}' not found") |
116 |
| - |
117 |
| - # do validation which is dependent on the client configuration |
118 |
| - redirect_uri = validate_redirect_uri(auth_request, client) |
119 |
| - scopes = validate_scope(auth_request.scope, client) |
120 |
| - |
121 |
| - auth_params = AuthorizationParams( |
122 |
| - state=auth_request.state, |
123 |
| - scopes=scopes, |
124 |
| - code_challenge=auth_request.code_challenge, |
125 |
| - redirect_uri=redirect_uri, |
126 |
| - ) |
| 185 | + if not client: |
| 186 | + # For client_id validation errors, return direct error (no redirect) |
| 187 | + return await error_response( |
| 188 | + error="invalid_request", |
| 189 | + error_description=f"Client ID '{auth_request.client_id}' not found", |
| 190 | + attempt_load_client=False, |
| 191 | + ) |
127 | 192 |
|
128 |
| - try: |
| 193 | + |
| 194 | + # Validate redirect_uri against client's registered URIs |
| 195 | + try: |
| 196 | + redirect_uri = validate_redirect_uri(auth_request.redirect_uri, client) |
| 197 | + except InvalidRequestError as validation_error: |
| 198 | + # For redirect_uri validation errors, return direct error (no redirect) |
| 199 | + return await error_response( |
| 200 | + error="invalid_request", |
| 201 | + error_description=validation_error.message, |
| 202 | + ) |
| 203 | + |
| 204 | + # Validate scope - for scope errors, we can redirect |
| 205 | + try: |
| 206 | + scopes = validate_scope(auth_request.scope, client) |
| 207 | + except InvalidRequestError as validation_error: |
| 208 | + # For scope errors, redirect with error parameters |
| 209 | + return await error_response( |
| 210 | + error="invalid_scope", |
| 211 | + error_description=validation_error.message, |
| 212 | + ) |
| 213 | + |
| 214 | + # Setup authorization parameters |
| 215 | + auth_params = AuthorizationParams( |
| 216 | + state=state, |
| 217 | + scopes=scopes, |
| 218 | + code_challenge=auth_request.code_challenge, |
| 219 | + redirect_uri=redirect_uri, |
| 220 | + ) |
| 221 | + |
129 | 222 | # Let the provider pick the next URI to redirect to
|
130 | 223 | response = RedirectResponse(
|
131 | 224 | url="", status_code=302, headers={"Cache-Control": "no-store"}
|
132 | 225 | )
|
133 | 226 | response.headers["location"] = await provider.authorize(
|
134 | 227 | client, auth_params
|
135 | 228 | )
|
136 |
| - |
137 | 229 | return response
|
138 |
| - except Exception as e: |
139 |
| - logger.exception("error from authorize()", exc_info=e) |
140 |
| - |
141 |
| - return RedirectResponse( |
142 |
| - url=create_error_redirect(redirect_uri, e, auth_request.state), |
143 |
| - status_code=302, |
144 |
| - headers={"Cache-Control": "no-store"}, |
145 |
| - ) |
| 230 | + |
| 231 | + except Exception as validation_error: |
| 232 | + # Catch-all for unexpected errors |
| 233 | + logger.exception("Unexpected error in authorization_handler", exc_info=validation_error) |
| 234 | + return await error_response(error="server_error", error_description="An unexpected error occurred") |
146 | 235 |
|
147 | 236 | return authorization_handler
|
148 | 237 |
|
149 | 238 |
|
150 | 239 | def create_error_redirect(
|
151 |
| - redirect_uri: AnyUrl, error: Exception, state: Optional[str] |
| 240 | + redirect_uri: AnyUrl, error: Union[Exception, ErrorResponse] |
152 | 241 | ) -> str:
|
153 | 242 | parsed_uri = urlparse(str(redirect_uri))
|
154 |
| - if isinstance(error, OAuthError): |
| 243 | + |
| 244 | + if isinstance(error, ErrorResponse): |
| 245 | + # Convert ErrorResponse to dict |
| 246 | + error_dict = error.model_dump(exclude_none=True) |
| 247 | + query_params = {} |
| 248 | + for key, value in error_dict.items(): |
| 249 | + if value is not None: |
| 250 | + if key == "error_uri" and hasattr(value, "__str__"): |
| 251 | + query_params[key] = str(value) |
| 252 | + else: |
| 253 | + query_params[key] = value |
| 254 | + |
| 255 | + elif isinstance(error, OAuthError): |
155 | 256 | query_params = {"error": error.error_code, "error_description": str(error)}
|
156 | 257 | else:
|
157 | 258 | query_params = {
|
158 |
| - "error": "internal_error", |
| 259 | + "error": "server_error", |
159 | 260 | "error_description": "An unknown error occurred",
|
160 | 261 | }
|
161 |
| - # TODO: should we add error_uri? |
162 |
| - # if error.error_uri: |
163 |
| - # query_params["error_uri"] = str(error.error_uri) |
164 |
| - if state: |
165 |
| - query_params["state"] = state |
166 | 262 |
|
167 | 263 | new_query = urlencode(query_params)
|
168 | 264 | if parsed_uri.query:
|
|
0 commit comments