Skip to content

Commit 8762d53

Browse files
Merge branch 'praboud/auth' into jerome/auth
Resolves conflicts in OAuth implementation by: - Preserving modern typing with TypedDict for responses - Combining validation improvements for auth endpoints - Maintaining consistent error handling pattern - Integrating tests for all authorization flows
2 parents 65db7b6 + 5e7a0bf commit 8762d53

File tree

18 files changed

+1345
-358
lines changed

18 files changed

+1345
-358
lines changed

CLAUDE.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,10 @@ This document contains critical information about working with this codebase. Fo
104104
- Add None checks
105105
- Narrow string types
106106
- Match existing patterns
107+
- Pytest:
108+
- If the tests aren't finding the anyio pytest mark, try adding PYTEST_DISABLE_PLUGIN_AUTOLOAD=""
109+
to the start of the pytest run command eg:
110+
`PYTEST_DISABLE_PLUGIN_AUTOLOAD="" uv run --frozen pytest`
107111

108112
3. Best Practices
109113
- Check git status before commits

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ dependencies = [
3030
"sse-starlette>=1.6.1",
3131
"pydantic-settings>=2.5.2",
3232
"uvicorn>=0.23.1",
33+
"python-multipart",
3334
]
3435

3536
[project.optional-dependencies]

src/mcp/server/auth/errors.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66

77
from typing import TypedDict
88

9+
from pydantic import ValidationError
10+
911

1012
class OAuthErrorResponse(TypedDict):
1113
"""OAuth error response format."""
@@ -150,3 +152,7 @@ class InsufficientScopeError(OAuthError):
150152
"""
151153

152154
error_code = "insufficient_scope"
155+
156+
157+
def stringify_pydantic_error(validation_error: ValidationError) -> str:
158+
return "\n".join(f"{'.'.join(str(loc) for loc in e['loc'])}: {e['msg']}" for e in validation_error.errors())

src/mcp/server/auth/handlers/authorize.py

Lines changed: 187 additions & 86 deletions
Original file line numberDiff line numberDiff line change
@@ -4,164 +4,265 @@
44
Corresponds to TypeScript file: src/server/auth/handlers/authorize.ts
55
"""
66

7-
from typing import Literal
8-
from urllib.parse import urlencode, urlparse, urlunparse
7+
from typing import Callable, Literal, Optional, Union
8+
from urllib.parse import parse_qs, urlencode, urlparse, urlunparse
99

10-
from pydantic import AnyHttpUrl, AnyUrl, BaseModel, Field, ValidationError
10+
from pydantic import AnyHttpUrl, AnyUrl, BaseModel, Field, RootModel, ValidationError
11+
from starlette.datastructures import FormData, QueryParams
1112
from starlette.requests import Request
1213
from starlette.responses import RedirectResponse, Response
1314

1415
from mcp.server.auth.errors import (
1516
InvalidClientError,
1617
InvalidRequestError,
1718
OAuthError,
19+
stringify_pydantic_error,
1820
)
19-
from mcp.server.auth.handlers.types import HandlerFn
20-
from mcp.server.auth.provider import AuthorizationParams, OAuthServerProvider
21+
from mcp.server.auth.provider import AuthorizationParams, OAuthServerProvider, construct_redirect_uri
22+
from mcp.shared.auth import OAuthClientInformationFull
23+
from mcp.server.auth.json_response import PydanticJSONResponse
2124

25+
import logging
2226

23-
class AuthorizationRequest(BaseModel):
24-
"""
25-
Model for the authorization request parameters.
27+
logger = logging.getLogger(__name__)
2628

27-
Corresponds to request schema in authorizationHandler in
28-
src/server/auth/handlers/authorize.ts
29-
"""
3029

30+
class AuthorizationRequest(BaseModel):
31+
# See https://datatracker.ietf.org/doc/html/rfc6749#section-4.1.1
3132
client_id: str = Field(..., description="The client ID")
3233
redirect_uri: AnyHttpUrl | None = Field(
33-
..., description="URL to redirect to after authorization"
34+
None, description="URL to redirect to after authorization"
3435
)
3536

37+
# see OAuthClientMetadata; we only support `code`
3638
response_type: Literal["code"] = Field(
3739
..., description="Must be 'code' for authorization code flow"
3840
)
3941
code_challenge: str = Field(..., description="PKCE code challenge")
4042
code_challenge_method: Literal["S256"] = Field(
41-
"S256", description="PKCE code challenge method"
43+
"S256", description="PKCE code challenge method, must be S256"
44+
)
45+
state: Optional[str] = Field(None, description="Optional state parameter")
46+
scope: Optional[str] = Field(
47+
None,
48+
description="Optional scope; if specified, should be "
49+
"a space-separated list of scope strings",
4250
)
43-
state: str | None = Field(None, description="Optional state parameter")
44-
scope: str | None = Field(None, description="Optional scope parameter")
45-
46-
class Config:
47-
extra = "ignore"
4851

4952

50-
def validate_scope(requested_scope: str | None, scope: str | None) -> list[str] | None:
53+
def validate_scope(
54+
requested_scope: str | None, client: OAuthClientInformationFull
55+
) -> list[str] | None:
5156
if requested_scope is None:
5257
return None
5358
requested_scopes = requested_scope.split(" ")
54-
allowed_scopes = [] if scope is None else scope.split(" ")
59+
allowed_scopes = [] if client.scope is None else client.scope.split(" ")
5560
for scope in requested_scopes:
5661
if scope not in allowed_scopes:
5762
raise InvalidRequestError(f"Client was not registered with scope {scope}")
5863
return requested_scopes
5964

6065

6166
def validate_redirect_uri(
62-
redirect_uri: AnyHttpUrl | None, redirect_uris: list[AnyHttpUrl]
67+
redirect_uri: AnyHttpUrl | None, client: OAuthClientInformationFull
6368
) -> AnyHttpUrl:
64-
if not redirect_uris:
65-
raise InvalidClientError("Client has no registered redirect URIs")
66-
6769
if redirect_uri is not None:
6870
# Validate redirect_uri against client's registered redirect URIs
69-
if redirect_uri not in redirect_uris:
71+
if redirect_uri not in client.redirect_uris:
7072
raise InvalidRequestError(
7173
f"Redirect URI '{redirect_uri}' not registered for client"
7274
)
7375
return redirect_uri
74-
elif len(redirect_uris) == 1:
75-
return redirect_uris[0]
76+
elif len(client.redirect_uris) == 1:
77+
return client.redirect_uris[0]
7678
else:
7779
raise InvalidRequestError(
7880
"redirect_uri must be specified when client has multiple registered URIs"
7981
)
8082

83+
ErrorCode = Literal[
84+
"invalid_request",
85+
"unauthorized_client",
86+
"access_denied",
87+
"unsupported_response_type",
88+
"invalid_scope",
89+
"server_error",
90+
"temporarily_unavailable"
91+
]
92+
93+
class ErrorResponse(BaseModel):
94+
error: ErrorCode
95+
error_description: str
96+
error_uri: Optional[AnyUrl] = None
97+
# must be set if provided in the request
98+
state: Optional[str]
99+
100+
def best_effort_extract_string(key: str, params: None | FormData | QueryParams) -> Optional[str]:
101+
if params is None:
102+
return None
103+
value = params.get(key)
104+
if isinstance(value, str):
105+
return value
106+
return None
81107

82-
def create_authorization_handler(provider: OAuthServerProvider) -> HandlerFn:
83-
"""
84-
Create a handler for the OAuth 2.0 Authorization endpoint.
85-
86-
Corresponds to authorizationHandler in src/server/auth/handlers/authorize.ts
108+
class AnyHttpUrlModel(RootModel):
109+
root: AnyHttpUrl
87110

88-
"""
89111

112+
def create_authorization_handler(provider: OAuthServerProvider) -> Callable:
90113
async def authorization_handler(request: Request) -> Response:
91-
"""
92-
Handler for the OAuth 2.0 Authorization endpoint.
93-
"""
94-
# Validate request parameters
114+
# implements authorization requests for grant_type=code;
115+
# see https://datatracker.ietf.org/doc/html/rfc6749#section-4.1.1
116+
117+
state = None
118+
redirect_uri = None
119+
client = None
120+
params = None
121+
122+
async def error_response(error: ErrorCode, error_description: str, attempt_load_client: bool = True):
123+
nonlocal client, redirect_uri, state
124+
if client is None and attempt_load_client:
125+
# make last-ditch attempt to load the client
126+
client_id = best_effort_extract_string("client_id", params)
127+
client = client_id and await provider.clients_store.get_client(client_id)
128+
if redirect_uri is None and client:
129+
# make last-ditch effort to load the redirect uri
130+
if params is not None and "redirect_uri" not in params:
131+
raw_redirect_uri = None
132+
else:
133+
raw_redirect_uri = AnyHttpUrlModel.model_validate(best_effort_extract_string("redirect_uri", params)).root
134+
try:
135+
redirect_uri = validate_redirect_uri(raw_redirect_uri, client)
136+
except (ValidationError, InvalidRequestError):
137+
pass
138+
if state is None:
139+
# make last-ditch effort to load state
140+
state = best_effort_extract_string("state", params)
141+
142+
error_resp = ErrorResponse(
143+
error=error,
144+
error_description=error_description,
145+
state=state,
146+
)
147+
148+
if redirect_uri and client:
149+
return RedirectResponse(
150+
url=construct_redirect_uri(str(redirect_uri), **error_resp.model_dump(exclude_none=True)),
151+
status_code=302,
152+
headers={"Cache-Control": "no-store"},
153+
)
154+
else:
155+
return PydanticJSONResponse(
156+
status_code=400,
157+
content=error_resp,
158+
headers={"Cache-Control": "no-store"},
159+
)
160+
95161
try:
162+
# Parse request parameters
96163
if request.method == "GET":
97164
# Convert query_params to dict for pydantic validation
98-
params = dict(request.query_params)
99-
auth_request = AuthorizationRequest.model_validate(params)
165+
params = request.query_params
100166
else:
101167
# Parse form data for POST requests
102-
form_data = await request.form()
103-
params = dict(form_data)
168+
params = await request.form()
169+
170+
# Save state if it exists, even before validation
171+
state = best_effort_extract_string("state", params)
172+
173+
try:
104174
auth_request = AuthorizationRequest.model_validate(params)
105-
except ValidationError as e:
106-
raise InvalidRequestError(str(e))
107-
108-
# Get client information
109-
client = await provider.clients_store.get_client(auth_request.client_id)
110-
111-
if not client:
112-
raise InvalidClientError(f"Client ID '{auth_request.client_id}' not found")
113-
114-
# do validation which is dependent on the client configuration
115-
redirect_uri = validate_redirect_uri(
116-
auth_request.redirect_uri, client.redirect_uris
117-
)
118-
scopes = validate_scope(auth_request.scope, client.scope)
119-
120-
auth_params = AuthorizationParams(
121-
state=auth_request.state,
122-
scopes=scopes,
123-
code_challenge=auth_request.code_challenge,
124-
redirect_uri=redirect_uri,
125-
)
126-
127-
response = RedirectResponse(
128-
url="", status_code=302, headers={"Cache-Control": "no-store"}
129-
)
130-
131-
try:
132-
# Let the provider handle the authorization flow
133-
await provider.authorize(client, auth_params, response)
134-
135-
return response
136-
except Exception as e:
137-
return RedirectResponse(
138-
url=create_error_redirect(redirect_uri, e, auth_request.state),
139-
status_code=302,
140-
headers={"Cache-Control": "no-store"},
175+
state = auth_request.state # Update with validated state
176+
except ValidationError as validation_error:
177+
error: ErrorCode = "invalid_request"
178+
for e in validation_error.errors():
179+
if e['loc'] == ('response_type',) and e['type'] == 'literal_error':
180+
error = "unsupported_response_type"
181+
break
182+
return await error_response(error, stringify_pydantic_error(validation_error))
183+
184+
# Get client information
185+
client = await provider.clients_store.get_client(auth_request.client_id)
186+
if not client:
187+
# For client_id validation errors, return direct error (no redirect)
188+
return await error_response(
189+
error="invalid_request",
190+
error_description=f"Client ID '{auth_request.client_id}' not found",
191+
attempt_load_client=False,
192+
)
193+
194+
195+
# Validate redirect_uri against client's registered URIs
196+
try:
197+
redirect_uri = validate_redirect_uri(auth_request.redirect_uri, client)
198+
except InvalidRequestError as validation_error:
199+
# For redirect_uri validation errors, return direct error (no redirect)
200+
return await error_response(
201+
error="invalid_request",
202+
error_description=validation_error.message,
203+
)
204+
205+
# Validate scope - for scope errors, we can redirect
206+
try:
207+
scopes = validate_scope(auth_request.scope, client)
208+
except InvalidRequestError as validation_error:
209+
# For scope errors, redirect with error parameters
210+
return await error_response(
211+
error="invalid_scope",
212+
error_description=validation_error.message,
213+
)
214+
215+
# Setup authorization parameters
216+
auth_params = AuthorizationParams(
217+
state=state,
218+
scopes=scopes,
219+
code_challenge=auth_request.code_challenge,
220+
redirect_uri=redirect_uri,
141221
)
222+
223+
# Let the provider pick the next URI to redirect to
224+
response = RedirectResponse(
225+
url="", status_code=302, headers={"Cache-Control": "no-store"}
226+
)
227+
response.headers["location"] = await provider.authorize(
228+
client, auth_params
229+
)
230+
return response
231+
232+
except Exception as validation_error:
233+
# Catch-all for unexpected errors
234+
logger.exception("Unexpected error in authorization_handler", exc_info=validation_error)
235+
return await error_response(error="server_error", error_description="An unexpected error occurred")
142236

143237
return authorization_handler
144238

145239

146240
def create_error_redirect(
147-
redirect_uri: AnyUrl, error: Exception, state: str | None
241+
redirect_uri: AnyUrl, error: Union[Exception, ErrorResponse]
148242
) -> str:
149243
parsed_uri = urlparse(str(redirect_uri))
150-
if isinstance(error, OAuthError):
244+
245+
if isinstance(error, ErrorResponse):
246+
# Convert ErrorResponse to dict
247+
error_dict = error.model_dump(exclude_none=True)
248+
query_params = {}
249+
for key, value in error_dict.items():
250+
if value is not None:
251+
if key == "error_uri" and hasattr(value, "__str__"):
252+
query_params[key] = str(value)
253+
else:
254+
query_params[key] = value
255+
256+
elif isinstance(error, OAuthError):
151257
query_params = {"error": error.error_code, "error_description": str(error)}
152258
else:
153259
query_params = {
154-
"error": "internal_error",
260+
"error": "server_error",
155261
"error_description": "An unknown error occurred",
156262
}
157-
# TODO: should we add error_uri?
158-
# if error.error_uri:
159-
# query_params["error_uri"] = str(error.error_uri)
160-
if state:
161-
query_params["state"] = state
162263

163264
new_query = urlencode(query_params)
164265
if parsed_uri.query:
165266
new_query = f"{parsed_uri.query}&{new_query}"
166267

167-
return urlunparse(parsed_uri._replace(query=new_query))
268+
return urlunparse(parsed_uri._replace(query=new_query))

0 commit comments

Comments
 (0)