diff --git a/examples/servers/simple-auth/mcp_simple_auth/server.py b/examples/servers/simple-auth/mcp_simple_auth/server.py index 51f449113..799b9b517 100644 --- a/examples/servers/simple-auth/mcp_simple_auth/server.py +++ b/examples/servers/simple-auth/mcp_simple_auth/server.py @@ -214,7 +214,7 @@ async def exchange_authorization_code( return OAuthToken( access_token=mcp_token, - token_type="bearer", + token_type="Bearer", expires_in=3600, scope=" ".join(authorization_code.scopes), ) diff --git a/src/mcp/shared/auth.py b/src/mcp/shared/auth.py index 22f8a971d..9c0fbbfb5 100644 --- a/src/mcp/shared/auth.py +++ b/src/mcp/shared/auth.py @@ -1,6 +1,6 @@ from typing import Any, Literal -from pydantic import AnyHttpUrl, BaseModel, Field +from pydantic import AnyHttpUrl, BaseModel, Field, field_validator class OAuthToken(BaseModel): @@ -9,11 +9,20 @@ class OAuthToken(BaseModel): """ access_token: str - token_type: Literal["bearer"] = "bearer" + token_type: Literal["Bearer"] = "Bearer" expires_in: int | None = None scope: str | None = None refresh_token: str | None = None + @field_validator("token_type", mode="before") + @classmethod + def normalize_token_type(cls, v: str | None) -> str | None: + if isinstance(v, str): + # Bearer is title-cased in the spec, so we normalize it + # https://datatracker.ietf.org/doc/html/rfc6750#section-4 + return v.title() + return v + class InvalidScopeError(Exception): def __init__(self, message: str): @@ -111,27 +120,19 @@ class OAuthMetadata(BaseModel): token_endpoint: AnyHttpUrl registration_endpoint: AnyHttpUrl | None = None scopes_supported: list[str] | None = None - response_types_supported: list[Literal["code"]] = ["code"] + response_types_supported: list[str] = ["code"] response_modes_supported: list[Literal["query", "fragment"]] | None = None - grant_types_supported: ( - list[Literal["authorization_code", "refresh_token"]] | None - ) = None - token_endpoint_auth_methods_supported: ( - list[Literal["none", "client_secret_post"]] | None - ) = None + grant_types_supported: list[str] | None = None + token_endpoint_auth_methods_supported: list[str] | None = None token_endpoint_auth_signing_alg_values_supported: None = None service_documentation: AnyHttpUrl | None = None ui_locales_supported: list[str] | None = None op_policy_uri: AnyHttpUrl | None = None op_tos_uri: AnyHttpUrl | None = None revocation_endpoint: AnyHttpUrl | None = None - revocation_endpoint_auth_methods_supported: ( - list[Literal["client_secret_post"]] | None - ) = None + revocation_endpoint_auth_methods_supported: list[str] | None = None revocation_endpoint_auth_signing_alg_values_supported: None = None introspection_endpoint: AnyHttpUrl | None = None - introspection_endpoint_auth_methods_supported: ( - list[Literal["client_secret_post"]] | None - ) = None + introspection_endpoint_auth_methods_supported: list[str] | None = None introspection_endpoint_auth_signing_alg_values_supported: None = None - code_challenge_methods_supported: list[Literal["S256"]] | None = None + code_challenge_methods_supported: list[str] | None = None diff --git a/tests/client/test_auth.py b/tests/client/test_auth.py index 2edaff946..0a431a146 100644 --- a/tests/client/test_auth.py +++ b/tests/client/test_auth.py @@ -91,7 +91,7 @@ def oauth_client_info(): def oauth_token(): return OAuthToken( access_token="test_access_token", - token_type="bearer", + token_type="Bearer", expires_in=3600, refresh_token="test_refresh_token", scope="read write", @@ -143,7 +143,8 @@ def test_generate_code_verifier(self, oauth_provider): verifiers = {oauth_provider._generate_code_verifier() for _ in range(10)} assert len(verifiers) == 10 - def test_generate_code_challenge(self, oauth_provider): + @pytest.mark.anyio + async def test_generate_code_challenge(self, oauth_provider): """Test PKCE code challenge generation.""" verifier = "test_code_verifier_123" challenge = oauth_provider._generate_code_challenge(verifier) @@ -161,7 +162,8 @@ def test_generate_code_challenge(self, oauth_provider): assert "+" not in challenge assert "/" not in challenge - def test_get_authorization_base_url(self, oauth_provider): + @pytest.mark.anyio + async def test_get_authorization_base_url(self, oauth_provider): """Test authorization base URL extraction.""" # Test with path assert ( @@ -348,11 +350,13 @@ async def test_register_oauth_client_failure(self, oauth_provider): None, ) - def test_has_valid_token_no_token(self, oauth_provider): + @pytest.mark.anyio + async def test_has_valid_token_no_token(self, oauth_provider): """Test token validation with no token.""" assert not oauth_provider._has_valid_token() - def test_has_valid_token_valid(self, oauth_provider, oauth_token): + @pytest.mark.anyio + async def test_has_valid_token_valid(self, oauth_provider, oauth_token): """Test token validation with valid token.""" oauth_provider._current_tokens = oauth_token oauth_provider._token_expiry_time = time.time() + 3600 # Future expiry @@ -370,7 +374,7 @@ async def test_has_valid_token_expired(self, oauth_provider, oauth_token): @pytest.mark.anyio async def test_validate_token_scopes_no_scope(self, oauth_provider): """Test scope validation with no scope returned.""" - token = OAuthToken(access_token="test", token_type="bearer") + token = OAuthToken(access_token="test", token_type="Bearer") # Should not raise exception await oauth_provider._validate_token_scopes(token) @@ -381,7 +385,7 @@ async def test_validate_token_scopes_valid(self, oauth_provider, client_metadata oauth_provider.client_metadata = client_metadata token = OAuthToken( access_token="test", - token_type="bearer", + token_type="Bearer", scope="read write", ) @@ -394,7 +398,7 @@ async def test_validate_token_scopes_subset(self, oauth_provider, client_metadat oauth_provider.client_metadata = client_metadata token = OAuthToken( access_token="test", - token_type="bearer", + token_type="Bearer", scope="read", ) @@ -409,7 +413,7 @@ async def test_validate_token_scopes_unauthorized( oauth_provider.client_metadata = client_metadata token = OAuthToken( access_token="test", - token_type="bearer", + token_type="Bearer", scope="read write admin", # Includes unauthorized "admin" ) @@ -423,7 +427,7 @@ async def test_validate_token_scopes_no_requested(self, oauth_provider): oauth_provider.client_metadata.scope = None token = OAuthToken( access_token="test", - token_type="bearer", + token_type="Bearer", scope="admin super", ) @@ -530,7 +534,7 @@ async def test_refresh_access_token_success( new_token = OAuthToken( access_token="new_access_token", - token_type="bearer", + token_type="Bearer", expires_in=3600, refresh_token="new_refresh_token", scope="read write", @@ -563,7 +567,7 @@ async def test_refresh_access_token_no_refresh_token(self, oauth_provider): """Test token refresh with no refresh token.""" oauth_provider._current_tokens = OAuthToken( access_token="test", - token_type="bearer", + token_type="Bearer", # No refresh_token ) @@ -756,7 +760,8 @@ async def test_async_auth_flow_no_token(self, oauth_provider): # No Authorization header should be added if no token assert "Authorization" not in updated_request.headers - def test_scope_priority_client_metadata_first( + @pytest.mark.anyio + async def test_scope_priority_client_metadata_first( self, oauth_provider, oauth_client_info ): """Test that client metadata scope takes priority.""" @@ -785,7 +790,8 @@ def test_scope_priority_client_metadata_first( assert auth_params["scope"] == "read write" - def test_scope_priority_no_client_metadata_scope( + @pytest.mark.anyio + async def test_scope_priority_no_client_metadata_scope( self, oauth_provider, oauth_client_info ): """Test that no scope parameter is set when client metadata has no scope.""" diff --git a/tests/server/fastmcp/auth/test_auth_integration.py b/tests/server/fastmcp/auth/test_auth_integration.py index d237e860e..13b38a563 100644 --- a/tests/server/fastmcp/auth/test_auth_integration.py +++ b/tests/server/fastmcp/auth/test_auth_integration.py @@ -96,7 +96,7 @@ async def exchange_authorization_code( return OAuthToken( access_token=access_token, - token_type="bearer", + token_type="Bearer", expires_in=3600, scope="read write", refresh_token=refresh_token, @@ -160,7 +160,7 @@ async def exchange_refresh_token( return OAuthToken( access_token=new_access_token, - token_type="bearer", + token_type="Bearer", expires_in=3600, scope=" ".join(scopes) if scopes else " ".join(token_info.scopes), refresh_token=new_refresh_token, @@ -831,7 +831,7 @@ async def test_authorization_get( assert "token_type" in token_response assert "refresh_token" in token_response assert "expires_in" in token_response - assert token_response["token_type"] == "bearer" + assert token_response["token_type"] == "Bearer" # 5. Verify the access token access_token = token_response["access_token"]