Skip to content

Commit 8e45e97

Browse files
committed
Improve /authorize validation & add tests
1 parent 3743e37 commit 8e45e97

File tree

3 files changed

+443
-104
lines changed

3 files changed

+443
-104
lines changed

src/mcp/server/auth/handlers/authorize.py

Lines changed: 157 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -4,30 +4,34 @@
44
Corresponds to TypeScript file: src/server/auth/handlers/authorize.ts
55
"""
66

7-
from typing import Callable, Literal, Optional
7+
from typing import Callable, Literal, Optional, Union
88
from urllib.parse import parse_qs, urlencode, urlparse, urlunparse
99

10-
from pydantic import AnyHttpUrl, AnyUrl, BaseModel, Field, ValidationError
10+
from pydantic import AnyHttpUrl, AnyUrl, BaseModel, Field, RootModel, ValidationError
11+
from starlette.datastructures import FormData, QueryParams
1112
from starlette.requests import Request
1213
from starlette.responses import RedirectResponse, Response
1314

1415
from mcp.server.auth.errors import (
1516
InvalidClientError,
1617
InvalidRequestError,
1718
OAuthError,
19+
stringify_pydantic_error,
1820
)
19-
from mcp.server.auth.provider import AuthorizationParams, OAuthServerProvider
21+
from mcp.server.auth.provider import AuthorizationParams, OAuthServerProvider, construct_redirect_uri
2022
from mcp.shared.auth import OAuthClientInformationFull
23+
from mcp.server.auth.json_response import PydanticJSONResponse
2124

2225
import logging
2326

2427
logger = logging.getLogger(__name__)
2528

2629

2730
class AuthorizationRequest(BaseModel):
31+
# See https://datatracker.ietf.org/doc/html/rfc6749#section-4.1.1
2832
client_id: str = Field(..., description="The client ID")
2933
redirect_uri: AnyHttpUrl | None = Field(
30-
..., description="URL to redirect to after authorization"
34+
None, description="URL to redirect to after authorization"
3135
)
3236

3337
# see OAuthClientMetadata; we only support `code`
@@ -61,108 +65,200 @@ def validate_scope(
6165

6266

6367
def validate_redirect_uri(
64-
auth_request: AuthorizationRequest, client: OAuthClientInformationFull
68+
redirect_uri: AnyHttpUrl | None, client: OAuthClientInformationFull
6569
) -> AnyHttpUrl:
66-
if auth_request.redirect_uri is not None:
70+
if redirect_uri is not None:
6771
# Validate redirect_uri against client's registered redirect URIs
68-
if auth_request.redirect_uri not in client.redirect_uris:
72+
if redirect_uri not in client.redirect_uris:
6973
raise InvalidRequestError(
70-
f"Redirect URI '{auth_request.redirect_uri}' not registered for client"
74+
f"Redirect URI '{redirect_uri}' not registered for client"
7175
)
72-
return auth_request.redirect_uri
76+
return redirect_uri
7377
elif len(client.redirect_uris) == 1:
7478
return client.redirect_uris[0]
7579
else:
7680
raise InvalidRequestError(
7781
"redirect_uri must be specified when client has multiple registered URIs"
7882
)
83+
ErrorCode = Literal[
84+
"invalid_request",
85+
"unauthorized_client",
86+
"access_denied",
87+
"unsupported_response_type",
88+
"invalid_scope",
89+
"server_error",
90+
"temporarily_unavailable"
91+
]
92+
class ErrorResponse(BaseModel):
93+
error: ErrorCode
94+
error_description: str
95+
error_uri: Optional[AnyUrl] = None
96+
# must be set if provided in the request
97+
state: Optional[str]
98+
99+
def best_effort_extract_string(key: str, params: None | FormData | QueryParams) -> Optional[str]:
100+
if params is None:
101+
return None
102+
value = params.get(key)
103+
if isinstance(value, str):
104+
return value
105+
return None
106+
107+
class AnyHttpUrlModel(RootModel):
108+
root: AnyHttpUrl
79109

80110

81111
def create_authorization_handler(provider: OAuthServerProvider) -> Callable:
82-
"""
83-
Create a handler for the OAuth 2.0 Authorization endpoint.
112+
async def authorization_handler(request: Request) -> Response:
113+
# implements authorization requests for grant_type=code;
114+
# see https://datatracker.ietf.org/doc/html/rfc6749#section-4.1.1
84115

85-
Corresponds to authorizationHandler in src/server/auth/handlers/authorize.ts
116+
state = None
117+
redirect_uri = None
118+
client = None
119+
params = None
86120

87-
"""
121+
async def error_response(error: ErrorCode, error_description: str, attempt_load_client: bool = True):
122+
nonlocal client, redirect_uri, state
123+
if client is None and attempt_load_client:
124+
# make last-ditch attempt to load the client
125+
client_id = best_effort_extract_string("client_id", params)
126+
client = client_id and await provider.clients_store.get_client(client_id)
127+
if redirect_uri is None and client:
128+
# make last-ditch effort to load the redirect uri
129+
if params is not None and "redirect_uri" not in params:
130+
raw_redirect_uri = None
131+
else:
132+
raw_redirect_uri = AnyHttpUrlModel.model_validate(best_effort_extract_string("redirect_uri", params)).root
133+
try:
134+
redirect_uri = validate_redirect_uri(raw_redirect_uri, client)
135+
except (ValidationError, InvalidRequestError):
136+
pass
137+
if state is None:
138+
# make last-ditch effort to load state
139+
state = best_effort_extract_string("state", params)
88140

89-
async def authorization_handler(request: Request) -> Response:
90-
"""
91-
Handler for the OAuth 2.0 Authorization endpoint.
92-
"""
93-
# Validate request parameters
141+
error_resp = ErrorResponse(
142+
error=error,
143+
error_description=error_description,
144+
state=state,
145+
)
146+
147+
if redirect_uri and client:
148+
return RedirectResponse(
149+
url=construct_redirect_uri(str(redirect_uri), **error_resp.model_dump(exclude_none=True)),
150+
status_code=302,
151+
headers={"Cache-Control": "no-store"},
152+
)
153+
else:
154+
return PydanticJSONResponse(
155+
status_code=400,
156+
content=error_resp,
157+
headers={"Cache-Control": "no-store"},
158+
)
159+
94160
try:
161+
# Parse request parameters
95162
if request.method == "GET":
96163
# Convert query_params to dict for pydantic validation
97-
params = dict(request.query_params)
98-
auth_request = AuthorizationRequest.model_validate(params)
164+
params = request.query_params
99165
else:
100166
# Parse form data for POST requests
101-
form_data = await request.form()
102-
params = dict(form_data)
167+
params = await request.form()
168+
169+
# Save state if it exists, even before validation
170+
state = best_effort_extract_string("state", params)
171+
172+
try:
103173
auth_request = AuthorizationRequest.model_validate(params)
104-
except ValidationError as e:
105-
raise InvalidRequestError(str(e))
174+
state = auth_request.state # Update with validated state
175+
except ValidationError as validation_error:
176+
error: ErrorCode = "invalid_request"
177+
for e in validation_error.errors():
178+
if e['loc'] == ('response_type',) and e['type'] == 'literal_error':
179+
error = "unsupported_response_type"
180+
break
181+
return await error_response(error, stringify_pydantic_error(validation_error))
106182

107-
# Get client information
108-
try:
183+
# Get client information
109184
client = await provider.clients_store.get_client(auth_request.client_id)
110-
except OAuthError as e:
111-
# TODO: proper error rendering
112-
raise InvalidClientError(str(e))
113-
114-
if not client:
115-
raise InvalidClientError(f"Client ID '{auth_request.client_id}' not found")
116-
117-
# do validation which is dependent on the client configuration
118-
redirect_uri = validate_redirect_uri(auth_request, client)
119-
scopes = validate_scope(auth_request.scope, client)
120-
121-
auth_params = AuthorizationParams(
122-
state=auth_request.state,
123-
scopes=scopes,
124-
code_challenge=auth_request.code_challenge,
125-
redirect_uri=redirect_uri,
126-
)
185+
if not client:
186+
# For client_id validation errors, return direct error (no redirect)
187+
return await error_response(
188+
error="invalid_request",
189+
error_description=f"Client ID '{auth_request.client_id}' not found",
190+
attempt_load_client=False,
191+
)
127192

128-
try:
193+
194+
# Validate redirect_uri against client's registered URIs
195+
try:
196+
redirect_uri = validate_redirect_uri(auth_request.redirect_uri, client)
197+
except InvalidRequestError as validation_error:
198+
# For redirect_uri validation errors, return direct error (no redirect)
199+
return await error_response(
200+
error="invalid_request",
201+
error_description=validation_error.message,
202+
)
203+
204+
# Validate scope - for scope errors, we can redirect
205+
try:
206+
scopes = validate_scope(auth_request.scope, client)
207+
except InvalidRequestError as validation_error:
208+
# For scope errors, redirect with error parameters
209+
return await error_response(
210+
error="invalid_scope",
211+
error_description=validation_error.message,
212+
)
213+
214+
# Setup authorization parameters
215+
auth_params = AuthorizationParams(
216+
state=state,
217+
scopes=scopes,
218+
code_challenge=auth_request.code_challenge,
219+
redirect_uri=redirect_uri,
220+
)
221+
129222
# Let the provider pick the next URI to redirect to
130223
response = RedirectResponse(
131224
url="", status_code=302, headers={"Cache-Control": "no-store"}
132225
)
133226
response.headers["location"] = await provider.authorize(
134227
client, auth_params
135228
)
136-
137229
return response
138-
except Exception as e:
139-
logger.exception("error from authorize()", exc_info=e)
140-
141-
return RedirectResponse(
142-
url=create_error_redirect(redirect_uri, e, auth_request.state),
143-
status_code=302,
144-
headers={"Cache-Control": "no-store"},
145-
)
230+
231+
except Exception as validation_error:
232+
# Catch-all for unexpected errors
233+
logger.exception("Unexpected error in authorization_handler", exc_info=validation_error)
234+
return await error_response(error="server_error", error_description="An unexpected error occurred")
146235

147236
return authorization_handler
148237

149238

150239
def create_error_redirect(
151-
redirect_uri: AnyUrl, error: Exception, state: Optional[str]
240+
redirect_uri: AnyUrl, error: Union[Exception, ErrorResponse]
152241
) -> str:
153242
parsed_uri = urlparse(str(redirect_uri))
154-
if isinstance(error, OAuthError):
243+
244+
if isinstance(error, ErrorResponse):
245+
# Convert ErrorResponse to dict
246+
error_dict = error.model_dump(exclude_none=True)
247+
query_params = {}
248+
for key, value in error_dict.items():
249+
if value is not None:
250+
if key == "error_uri" and hasattr(value, "__str__"):
251+
query_params[key] = str(value)
252+
else:
253+
query_params[key] = value
254+
255+
elif isinstance(error, OAuthError):
155256
query_params = {"error": error.error_code, "error_description": str(error)}
156257
else:
157258
query_params = {
158-
"error": "internal_error",
259+
"error": "server_error",
159260
"error_description": "An unknown error occurred",
160261
}
161-
# TODO: should we add error_uri?
162-
# if error.error_uri:
163-
# query_params["error_uri"] = str(error.error_uri)
164-
if state:
165-
query_params["state"] = state
166262

167263
new_query = urlencode(query_params)
168264
if parsed_uri.query:

src/mcp/server/auth/provider.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -211,12 +211,12 @@ async def revoke_token(
211211
"""
212212
...
213213

214-
def construct_redirect_uri(redirect_uri_base: str, authorization_code: AuthorizationCode, state: Optional[str]) -> str:
214+
def construct_redirect_uri(redirect_uri_base: str, **params: str | None) -> str:
215215
parsed_uri = urlparse(redirect_uri_base)
216216
query_params = [(k, v) for k, vs in parse_qs(parsed_uri.query) for v in vs]
217-
query_params.append(("code", authorization_code.code))
218-
if state:
219-
query_params.append(("state", state))
217+
for k, v in params.items():
218+
if v is not None:
219+
query_params.append((k, v))
220220

221221
redirect_uri = urlunparse(
222222
parsed_uri._replace(query=urlencode(query_params))

0 commit comments

Comments
 (0)