Skip to content

Commit 3743e37

Browse files
committed
Improve validation for registration
1 parent ad74aee commit 3743e37

File tree

5 files changed

+129
-84
lines changed

5 files changed

+129
-84
lines changed

src/mcp/server/auth/errors.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66

77
from typing import Dict
88

9+
from pydantic import ValidationError
10+
911

1012
class OAuthError(Exception):
1113
"""
@@ -143,3 +145,6 @@ class InsufficientScopeError(OAuthError):
143145
"""
144146

145147
error_code = "insufficient_scope"
148+
149+
def stringify_pydantic_error(validation_error: ValidationError) -> str:
150+
return "\n".join(f"{'.'.join(str(loc) for loc in e['loc'])}: {e['msg']}" for e in validation_error.errors())

src/mcp/server/auth/handlers/register.py

Lines changed: 57 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -6,103 +6,83 @@
66

77
import secrets
88
import time
9-
from typing import Callable
9+
from typing import Callable, Literal
1010
from uuid import uuid4
1111

12-
from pydantic import ValidationError
12+
from pydantic import BaseModel, ValidationError
1313
from starlette.requests import Request
1414
from starlette.responses import JSONResponse, Response
1515

1616
from mcp.server.auth.errors import (
1717
InvalidRequestError,
1818
OAuthError,
1919
ServerError,
20+
stringify_pydantic_error,
2021
)
2122
from mcp.server.auth.json_response import PydanticJSONResponse
2223
from mcp.server.auth.provider import OAuthRegisteredClientsStore
2324
from mcp.shared.auth import OAuthClientInformationFull, OAuthClientMetadata
2425

26+
class ErrorResponse(BaseModel):
27+
error: Literal["invalid_redirect_uri", "invalid_client_metadata", "invalid_software_statement", "unapproved_software_statement"]
28+
error_description: str
29+
2530

2631
def create_registration_handler(
2732
clients_store: OAuthRegisteredClientsStore, client_secret_expiry_seconds: int | None
2833
) -> Callable:
29-
"""
30-
Create a handler for OAuth 2.0 Dynamic Client Registration.
31-
32-
Corresponds to clientRegistrationHandler in src/server/auth/handlers/register.ts
33-
34-
Args:
35-
clients_store: The store for registered clients
36-
client_secret_expiry_seconds: Optional expiry time for client secrets
37-
38-
Returns:
39-
A Starlette endpoint handler function
40-
"""
41-
4234
async def registration_handler(request: Request) -> Response:
43-
"""
44-
Handler for the OAuth 2.0 Dynamic Client Registration endpoint.
45-
46-
Args:
47-
request: The Starlette request
48-
49-
Returns:
50-
JSON response with client information or error
51-
"""
35+
# Implements dynamic client registration as defined in https://datatracker.ietf.org/doc/html/rfc7591#section-3.1
5236
try:
5337
# Parse request body as JSON
54-
try:
55-
body = await request.json()
56-
client_metadata = OAuthClientMetadata.model_validate(body)
57-
except ValidationError as e:
58-
raise InvalidRequestError(f"Invalid client metadata: {str(e)}")
59-
60-
client_id = str(uuid4())
61-
client_secret = None
62-
if client_metadata.token_endpoint_auth_method != "none":
63-
# cryptographically secure random 32-byte hex string
64-
client_secret = secrets.token_hex(32)
65-
66-
client_id_issued_at = int(time.time())
67-
client_secret_expires_at = (
68-
client_id_issued_at + client_secret_expiry_seconds
69-
if client_secret_expiry_seconds is not None
70-
else None
71-
)
72-
73-
client_info = OAuthClientInformationFull(
74-
client_id=client_id,
75-
client_id_issued_at=client_id_issued_at,
76-
client_secret=client_secret,
77-
client_secret_expires_at=client_secret_expires_at,
78-
# passthrough information from the client request
79-
redirect_uris=client_metadata.redirect_uris,
80-
token_endpoint_auth_method=client_metadata.token_endpoint_auth_method,
81-
grant_types=client_metadata.grant_types,
82-
response_types=client_metadata.response_types,
83-
client_name=client_metadata.client_name,
84-
client_uri=client_metadata.client_uri,
85-
logo_uri=client_metadata.logo_uri,
86-
scope=client_metadata.scope,
87-
contacts=client_metadata.contacts,
88-
tos_uri=client_metadata.tos_uri,
89-
policy_uri=client_metadata.policy_uri,
90-
jwks_uri=client_metadata.jwks_uri,
91-
jwks=client_metadata.jwks,
92-
software_id=client_metadata.software_id,
93-
software_version=client_metadata.software_version,
94-
)
95-
# Register client
96-
client = await clients_store.register_client(client_info)
97-
if not client:
98-
raise ServerError("Failed to register client")
99-
100-
# Return client information
101-
return PydanticJSONResponse(content=client, status_code=201)
102-
103-
except OAuthError as e:
104-
# Handle OAuth errors
105-
status_code = 500 if isinstance(e, ServerError) else 400
106-
return JSONResponse(status_code=status_code, content=e.to_response_object())
38+
body = await request.json()
39+
client_metadata = OAuthClientMetadata.model_validate(body)
40+
except ValidationError as validation_error:
41+
return PydanticJSONResponse(content=ErrorResponse(
42+
error="invalid_client_metadata",
43+
error_description=stringify_pydantic_error(validation_error)
44+
), status_code=400)
45+
raise InvalidRequestError(f"Invalid client metadata: {str(e)}")
46+
47+
client_id = str(uuid4())
48+
client_secret = None
49+
if client_metadata.token_endpoint_auth_method != "none":
50+
# cryptographically secure random 32-byte hex string
51+
client_secret = secrets.token_hex(32)
52+
53+
client_id_issued_at = int(time.time())
54+
client_secret_expires_at = (
55+
client_id_issued_at + client_secret_expiry_seconds
56+
if client_secret_expiry_seconds is not None
57+
else None
58+
)
59+
60+
client_info = OAuthClientInformationFull(
61+
client_id=client_id,
62+
client_id_issued_at=client_id_issued_at,
63+
client_secret=client_secret,
64+
client_secret_expires_at=client_secret_expires_at,
65+
# passthrough information from the client request
66+
redirect_uris=client_metadata.redirect_uris,
67+
token_endpoint_auth_method=client_metadata.token_endpoint_auth_method,
68+
grant_types=client_metadata.grant_types,
69+
response_types=client_metadata.response_types,
70+
client_name=client_metadata.client_name,
71+
client_uri=client_metadata.client_uri,
72+
logo_uri=client_metadata.logo_uri,
73+
scope=client_metadata.scope,
74+
contacts=client_metadata.contacts,
75+
tos_uri=client_metadata.tos_uri,
76+
policy_uri=client_metadata.policy_uri,
77+
jwks_uri=client_metadata.jwks_uri,
78+
jwks=client_metadata.jwks,
79+
software_id=client_metadata.software_id,
80+
software_version=client_metadata.software_version,
81+
)
82+
# Register client
83+
client = await clients_store.register_client(client_info)
84+
85+
# Return client information
86+
return PydanticJSONResponse(content=client, status_code=201)
10787

10888
return registration_handler

src/mcp/server/auth/handlers/token.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
from mcp.server.auth.errors import (
1616
InvalidRequestError,
17+
stringify_pydantic_error,
1718
)
1819
from mcp.server.auth.json_response import PydanticJSONResponse
1920
from mcp.server.auth.middleware.client_auth import (
@@ -74,7 +75,7 @@ async def token_handler(request: Request):
7475
except ValidationError as validation_error:
7576
return response(TokenErrorResponse(
7677
error="invalid_request",
77-
error_description="\n".join(e['msg'] for e in validation_error.errors())
78+
error_description=stringify_pydantic_error(validation_error)
7879

7980
))
8081
client_info = await client_authenticator(token_request)

src/mcp/server/auth/provider.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -72,15 +72,12 @@ async def get_client(self, client_id: str) -> Optional[OAuthClientInformationFul
7272

7373
async def register_client(
7474
self, client_info: OAuthClientInformationFull
75-
) -> Optional[OAuthClientInformationFull]:
75+
) -> None:
7676
"""
77-
Registers a new client and returns client information.
77+
Registers a new client
7878
7979
Args:
80-
metadata: The client metadata to register.
81-
82-
Returns:
83-
The client information, or None if registration failed.
80+
client_info: The client metadata to register.
8481
"""
8582
...
8683

tests/server/fastmcp/auth/test_auth_integration.py

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -681,6 +681,68 @@ async def test_client_registration(
681681
# client_info["client_id"]
682682
# ) is not None
683683

684+
@pytest.mark.anyio
685+
async def test_client_registration_missing_required_fields(
686+
self, test_client: httpx.AsyncClient
687+
):
688+
"""Test client registration with missing required fields."""
689+
# Missing redirect_uris which is a required field
690+
client_metadata = {
691+
"client_name": "Test Client",
692+
"client_uri": "https://client.example.com",
693+
}
694+
695+
response = await test_client.post(
696+
"/register",
697+
json=client_metadata,
698+
)
699+
assert response.status_code == 400
700+
error_data = response.json()
701+
assert "error" in error_data
702+
assert error_data["error"] == "invalid_client_metadata"
703+
assert error_data["error_description"] == "redirect_uris: Field required"
704+
705+
@pytest.mark.anyio
706+
async def test_client_registration_invalid_uri(
707+
self, test_client: httpx.AsyncClient
708+
):
709+
"""Test client registration with invalid URIs."""
710+
# Invalid redirect_uri format
711+
client_metadata = {
712+
"redirect_uris": ["not-a-valid-uri"],
713+
"client_name": "Test Client",
714+
}
715+
716+
response = await test_client.post(
717+
"/register",
718+
json=client_metadata,
719+
)
720+
assert response.status_code == 400
721+
error_data = response.json()
722+
assert "error" in error_data
723+
assert error_data["error"] == "invalid_client_metadata"
724+
assert error_data["error_description"] == "redirect_uris.0: Input should be a valid URL, relative URL without a base"
725+
726+
@pytest.mark.anyio
727+
async def test_client_registration_empty_redirect_uris(
728+
self, test_client: httpx.AsyncClient
729+
):
730+
"""Test client registration with empty redirect_uris array."""
731+
client_metadata = {
732+
"redirect_uris": [], # Empty array
733+
"client_name": "Test Client",
734+
}
735+
736+
response = await test_client.post(
737+
"/register",
738+
json=client_metadata,
739+
)
740+
assert response.status_code == 400
741+
error_data = response.json()
742+
assert "error" in error_data
743+
assert error_data["error"] == "invalid_client_metadata"
744+
assert error_data["error_description"] == "redirect_uris: List should have at least 1 item after validation, not 0"
745+
684746
@pytest.mark.anyio
685747
async def test_authorize_form_post(
686748
self, test_client: httpx.AsyncClient, mock_oauth_provider: MockOAuthProvider

0 commit comments

Comments
 (0)