Skip to content

relax validation #879

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion examples/servers/simple-auth/mcp_simple_auth/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,7 @@ async def exchange_authorization_code(

return OAuthToken(
access_token=mcp_token,
token_type="bearer",
token_type="Bearer",
expires_in=3600,
scope=" ".join(authorization_code.scopes),
)
Expand Down
33 changes: 17 additions & 16 deletions src/mcp/shared/auth.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import Any, Literal

from pydantic import AnyHttpUrl, BaseModel, Field
from pydantic import AnyHttpUrl, BaseModel, Field, field_validator


class OAuthToken(BaseModel):
Expand All @@ -9,11 +9,20 @@ class OAuthToken(BaseModel):
"""

access_token: str
token_type: Literal["bearer"] = "bearer"
token_type: Literal["Bearer"] = "Bearer"
expires_in: int | None = None
scope: str | None = None
refresh_token: str | None = None

@field_validator("token_type", mode="before")
@classmethod
def normalize_token_type(cls, v: str | None) -> str | None:
if isinstance(v, str):
# Bearer is title-cased in the spec, so we normalize it
# https://datatracker.ietf.org/doc/html/rfc6750#section-4
return v.title()
return v


class InvalidScopeError(Exception):
def __init__(self, message: str):
Expand Down Expand Up @@ -111,27 +120,19 @@ class OAuthMetadata(BaseModel):
token_endpoint: AnyHttpUrl
registration_endpoint: AnyHttpUrl | None = None
scopes_supported: list[str] | None = None
response_types_supported: list[Literal["code"]] = ["code"]
response_types_supported: list[str] = ["code"]
response_modes_supported: list[Literal["query", "fragment"]] | None = None
grant_types_supported: (
list[Literal["authorization_code", "refresh_token"]] | None
) = None
token_endpoint_auth_methods_supported: (
list[Literal["none", "client_secret_post"]] | None
) = None
grant_types_supported: list[str] | None = None
token_endpoint_auth_methods_supported: list[str] | None = None
token_endpoint_auth_signing_alg_values_supported: None = None
service_documentation: AnyHttpUrl | None = None
ui_locales_supported: list[str] | None = None
op_policy_uri: AnyHttpUrl | None = None
op_tos_uri: AnyHttpUrl | None = None
revocation_endpoint: AnyHttpUrl | None = None
revocation_endpoint_auth_methods_supported: (
list[Literal["client_secret_post"]] | None
) = None
revocation_endpoint_auth_methods_supported: list[str] | None = None
revocation_endpoint_auth_signing_alg_values_supported: None = None
introspection_endpoint: AnyHttpUrl | None = None
introspection_endpoint_auth_methods_supported: (
list[Literal["client_secret_post"]] | None
) = None
introspection_endpoint_auth_methods_supported: list[str] | None = None
introspection_endpoint_auth_signing_alg_values_supported: None = None
code_challenge_methods_supported: list[Literal["S256"]] | None = None
code_challenge_methods_supported: list[str] | None = None
34 changes: 20 additions & 14 deletions tests/client/test_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ def oauth_client_info():
def oauth_token():
return OAuthToken(
access_token="test_access_token",
token_type="bearer",
token_type="Bearer",
expires_in=3600,
refresh_token="test_refresh_token",
scope="read write",
Expand Down Expand Up @@ -143,7 +143,8 @@ def test_generate_code_verifier(self, oauth_provider):
verifiers = {oauth_provider._generate_code_verifier() for _ in range(10)}
assert len(verifiers) == 10

def test_generate_code_challenge(self, oauth_provider):
@pytest.mark.anyio
async def test_generate_code_challenge(self, oauth_provider):
"""Test PKCE code challenge generation."""
verifier = "test_code_verifier_123"
challenge = oauth_provider._generate_code_challenge(verifier)
Expand All @@ -161,7 +162,8 @@ def test_generate_code_challenge(self, oauth_provider):
assert "+" not in challenge
assert "/" not in challenge

def test_get_authorization_base_url(self, oauth_provider):
@pytest.mark.anyio
async def test_get_authorization_base_url(self, oauth_provider):
"""Test authorization base URL extraction."""
# Test with path
assert (
Expand Down Expand Up @@ -348,11 +350,13 @@ async def test_register_oauth_client_failure(self, oauth_provider):
None,
)

def test_has_valid_token_no_token(self, oauth_provider):
@pytest.mark.anyio
async def test_has_valid_token_no_token(self, oauth_provider):
"""Test token validation with no token."""
assert not oauth_provider._has_valid_token()

def test_has_valid_token_valid(self, oauth_provider, oauth_token):
@pytest.mark.anyio
async def test_has_valid_token_valid(self, oauth_provider, oauth_token):
"""Test token validation with valid token."""
oauth_provider._current_tokens = oauth_token
oauth_provider._token_expiry_time = time.time() + 3600 # Future expiry
Expand All @@ -370,7 +374,7 @@ async def test_has_valid_token_expired(self, oauth_provider, oauth_token):
@pytest.mark.anyio
async def test_validate_token_scopes_no_scope(self, oauth_provider):
"""Test scope validation with no scope returned."""
token = OAuthToken(access_token="test", token_type="bearer")
token = OAuthToken(access_token="test", token_type="Bearer")

# Should not raise exception
await oauth_provider._validate_token_scopes(token)
Expand All @@ -381,7 +385,7 @@ async def test_validate_token_scopes_valid(self, oauth_provider, client_metadata
oauth_provider.client_metadata = client_metadata
token = OAuthToken(
access_token="test",
token_type="bearer",
token_type="Bearer",
scope="read write",
)

Expand All @@ -394,7 +398,7 @@ async def test_validate_token_scopes_subset(self, oauth_provider, client_metadat
oauth_provider.client_metadata = client_metadata
token = OAuthToken(
access_token="test",
token_type="bearer",
token_type="Bearer",
scope="read",
)

Expand All @@ -409,7 +413,7 @@ async def test_validate_token_scopes_unauthorized(
oauth_provider.client_metadata = client_metadata
token = OAuthToken(
access_token="test",
token_type="bearer",
token_type="Bearer",
scope="read write admin", # Includes unauthorized "admin"
)

Expand All @@ -423,7 +427,7 @@ async def test_validate_token_scopes_no_requested(self, oauth_provider):
oauth_provider.client_metadata.scope = None
token = OAuthToken(
access_token="test",
token_type="bearer",
token_type="Bearer",
scope="admin super",
)

Expand Down Expand Up @@ -530,7 +534,7 @@ async def test_refresh_access_token_success(

new_token = OAuthToken(
access_token="new_access_token",
token_type="bearer",
token_type="Bearer",
expires_in=3600,
refresh_token="new_refresh_token",
scope="read write",
Expand Down Expand Up @@ -563,7 +567,7 @@ async def test_refresh_access_token_no_refresh_token(self, oauth_provider):
"""Test token refresh with no refresh token."""
oauth_provider._current_tokens = OAuthToken(
access_token="test",
token_type="bearer",
token_type="Bearer",
# No refresh_token
)

Expand Down Expand Up @@ -756,7 +760,8 @@ async def test_async_auth_flow_no_token(self, oauth_provider):
# No Authorization header should be added if no token
assert "Authorization" not in updated_request.headers

def test_scope_priority_client_metadata_first(
@pytest.mark.anyio
async def test_scope_priority_client_metadata_first(
self, oauth_provider, oauth_client_info
):
"""Test that client metadata scope takes priority."""
Expand Down Expand Up @@ -785,7 +790,8 @@ def test_scope_priority_client_metadata_first(

assert auth_params["scope"] == "read write"

def test_scope_priority_no_client_metadata_scope(
@pytest.mark.anyio
async def test_scope_priority_no_client_metadata_scope(
self, oauth_provider, oauth_client_info
):
"""Test that no scope parameter is set when client metadata has no scope."""
Expand Down
6 changes: 3 additions & 3 deletions tests/server/fastmcp/auth/test_auth_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ async def exchange_authorization_code(

return OAuthToken(
access_token=access_token,
token_type="bearer",
token_type="Bearer",
expires_in=3600,
scope="read write",
refresh_token=refresh_token,
Expand Down Expand Up @@ -160,7 +160,7 @@ async def exchange_refresh_token(

return OAuthToken(
access_token=new_access_token,
token_type="bearer",
token_type="Bearer",
expires_in=3600,
scope=" ".join(scopes) if scopes else " ".join(token_info.scopes),
refresh_token=new_refresh_token,
Expand Down Expand Up @@ -831,7 +831,7 @@ async def test_authorization_get(
assert "token_type" in token_response
assert "refresh_token" in token_response
assert "expires_in" in token_response
assert token_response["token_type"] == "bearer"
assert token_response["token_type"] == "Bearer"

# 5. Verify the access token
access_token = token_response["access_token"]
Expand Down
Loading