Skip to content

Commit 379cb7f

Browse files
committed
Added JWKS based token validation check
1 parent df2c91a commit 379cb7f

File tree

5 files changed

+155
-52
lines changed

5 files changed

+155
-52
lines changed

examples/servers/simple-auth-remote/mcp_simple_remote_auth/server.py

Lines changed: 108 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,35 +1,118 @@
11
"""Simple MCP Server with GitHub OAuth Authentication."""
22

33
import logging
4-
import secrets
5-
import time
64
from typing import Any, Literal
75

86
import click
7+
import jwt
8+
import requests
99
from pydantic import AnyHttpUrl
1010
from pydantic_settings import BaseSettings, SettingsConfigDict
11-
from starlette.exceptions import HTTPException
12-
from starlette.requests import Request
13-
from starlette.responses import JSONResponse, RedirectResponse, Response
1411

15-
from mcp.server.auth.middleware.auth_context import get_access_token
1612
from mcp.server.auth.provider import (
1713
AccessToken,
18-
AuthorizationCode,
19-
AuthorizationParams,
20-
OAuthAuthorizationServerProvider,
21-
RefreshToken,
2214
TokenValidator,
23-
construct_redirect_uri,
2415
)
2516
from mcp.server.auth.settings import AuthSettings, ClientRegistrationOptions
2617
from mcp.server.fastmcp.server import FastMCP
27-
from mcp.shared._httpx_utils import create_mcp_http_client
28-
from mcp.shared.auth import OAuthClientInformationFull, OAuthToken
18+
from mcp.shared.auth import ProtectedResourceMetadata
2919

3020
logger = logging.getLogger(__name__)
3121

3222

23+
class TokenValidatorJWT(TokenValidator[AccessToken]):
24+
def __init__(self, resource_metadata: ProtectedResourceMetadata):
25+
self._resource_metadata = resource_metadata
26+
27+
async def validate_token(self, token: str) -> AccessToken | None:
28+
try:
29+
return await self.decode_token(token)
30+
except Exception as e:
31+
logger.error(f"Token validation failed: {e}")
32+
return None
33+
34+
async def _get_jwks_uri(self, auth_server: str) -> str:
35+
"""Get the JWKS URI from the OIDC or OAuth well-known configuration.
36+
37+
Args:
38+
auth_server: The base URL of the authorization server
39+
40+
Returns:
41+
The JWKS URI
42+
43+
Raises:
44+
ValueError: If the JWKS URI cannot be found in either OIDC or OAuth
45+
well-known configurations
46+
requests.RequestException: If there's an error fetching the configuration
47+
"""
48+
well_known_paths = [
49+
"/.well-known/openid-configuration", # OIDC well-known
50+
"/.well-known/oauth-authorization-server", # OAuth well-known
51+
]
52+
53+
last_error = None
54+
55+
for path in well_known_paths:
56+
try:
57+
config_url = f"https://{auth_server}{path}"
58+
response = requests.get(
59+
config_url,
60+
timeout=10, # Add timeout to prevent hanging
61+
headers={"Accept": "application/json"},
62+
)
63+
response.raise_for_status() # Raise an exception for bad status codes
64+
config = response.json()
65+
66+
# Try to get JWKS URI from the configuration
67+
jwks_uri = config.get("jwks_uri")
68+
if jwks_uri:
69+
return jwks_uri
70+
71+
except requests.RequestException as e:
72+
last_error = e
73+
logger.debug(f"Failed to fetch {path}: {e}")
74+
continue
75+
76+
# If we get here, we couldn't find a valid JWKS URI
77+
error_msg = "Could not find jwks_uri in OIDC or OAuth well-known configurations"
78+
logger.error(f"{error_msg}. Last error: {last_error}")
79+
raise ValueError(error_msg)
80+
81+
async def decode_token(self, token: str) -> AccessToken | None:
82+
try:
83+
auth_server = self._resource_metadata.authorization_servers[0]
84+
jwks_uri = await self._get_jwks_uri(auth_server)
85+
jwks_client = jwt.PyJWKClient(jwks_uri)
86+
signing_key = jwks_client.get_signing_key_from_jwt(token)
87+
88+
# Rest of your decode_token method remains the same
89+
payload = jwt.decode(
90+
token,
91+
key=signing_key.key,
92+
algorithms=["RS256"],
93+
audience=self._resource_metadata.resource,
94+
issuer=f"https://{auth_server}",
95+
options={
96+
"verify_signature": True,
97+
"verify_aud": True,
98+
"verify_iss": True,
99+
"verify_exp": True,
100+
"verify_nbf": True,
101+
"verify_iat": True,
102+
},
103+
)
104+
105+
return AccessToken(
106+
token=token,
107+
client_id=payload["client_id"],
108+
scopes=payload["scope"].split(" "),
109+
expires_at=payload["exp"],
110+
)
111+
except Exception as e:
112+
logger.error(f"Token validation failed: {e}")
113+
return None
114+
115+
33116
class ServerSettings(BaseSettings):
34117
"""Settings for the simple GitHub MCP server."""
35118

@@ -71,8 +154,18 @@ def create_simple_mcp_server(settings: ServerSettings) -> FastMCP:
71154
port=settings.port,
72155
debug=True,
73156
auth=auth_settings,
74-
token_validator=TokenValidator(),
75-
protected_resource_metadata={"resource": "asdasd", "authorization_servers": ["https://auth.devramp.ai"], "scopes_supported": ["user"]}
157+
token_validator=TokenValidatorJWT(
158+
ProtectedResourceMetadata(
159+
resource="asdasd",
160+
authorization_servers=["https://auth.devramp.ai"],
161+
scopes_supported=["user"],
162+
)
163+
),
164+
protected_resource_metadata={
165+
"resource": "asdasd",
166+
"authorization_servers": ["https://auth.devramp.ai"],
167+
"scopes_supported": ["user"],
168+
},
76169
)
77170

78171
@app.tool()

src/mcp/server/auth/middleware/bearer_auth.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,11 @@
1010
from starlette.requests import HTTPConnection
1111
from starlette.types import Receive, Scope, Send
1212

13-
from mcp.server.auth.provider import AccessToken, OAuthAuthorizationServerProvider, TokenValidator
13+
from mcp.server.auth.provider import (
14+
AccessToken,
15+
OAuthAuthorizationServerProvider,
16+
TokenValidator,
17+
)
1418

1519

1620
class AuthenticatedUser(SimpleUser):

src/mcp/server/auth/provider.py

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from dataclasses import dataclass
2-
from typing import Generic, Literal, Protocol, TypeVar, Any
2+
from typing import Generic, Literal, Protocol, TypeVar
33
from urllib.parse import parse_qs, urlencode, urlparse, urlunparse
44

55
from pydantic import AnyHttpUrl, BaseModel
@@ -10,14 +10,6 @@
1010
)
1111

1212

13-
# Define type variables
14-
AccessTokenT = TypeVar('AccessTokenT', bound='AccessToken')
15-
16-
class TokenValidator(Generic[AccessTokenT], BaseModel):
17-
async def validate_token(self, token: str) -> AccessTokenT | None:
18-
...
19-
20-
2113
class AuthorizationParams(BaseModel):
2214
state: str | None
2315
scopes: list[str] | None
@@ -104,6 +96,13 @@ class TokenError(Exception):
10496
AccessTokenT = TypeVar("AccessTokenT", bound=AccessToken)
10597

10698

99+
100+
class TokenValidator(BaseModel, Generic[AccessTokenT]):
101+
async def validate_token(self, token: str) -> AccessTokenT | None:
102+
...
103+
104+
105+
107106
class OAuthAuthorizationServerProvider(
108107
Protocol, Generic[AuthorizationCodeT, RefreshTokenT, AccessTokenT]
109108
):

src/mcp/server/fastmcp/server.py

Lines changed: 33 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,7 @@
55
import inspect
66
import re
77
from collections.abc import AsyncIterator, Awaitable, Callable, Iterable, Sequence
8-
from contextlib import (
9-
AbstractAsyncContextManager,
10-
asynccontextmanager,
11-
)
8+
from contextlib import AbstractAsyncContextManager, asynccontextmanager
129
from itertools import chain
1310
from typing import Any, Generic, Literal
1411

@@ -18,9 +15,9 @@
1815
from pydantic.networks import AnyUrl
1916
from pydantic_settings import BaseSettings, SettingsConfigDict
2017
from starlette.applications import Starlette
18+
from starlette.exceptions import HTTPException
2119
from starlette.middleware import Middleware
2220
from starlette.middleware.authentication import AuthenticationMiddleware
23-
from starlette.exceptions import HTTPException
2421
from starlette.requests import Request
2522
from starlette.responses import Response
2623
from starlette.routing import Mount, Route
@@ -32,12 +29,12 @@
3229
JWTBearerTokenAuthBackend,
3330
RequireAuthMiddleware,
3431
)
35-
from mcp.server.auth.provider import AccessToken, TokenValidator
36-
from mcp.server.auth.provider import OAuthAuthorizationServerProvider
37-
from mcp.shared.auth import ProtectedResourceMetadata
38-
from mcp.server.auth.settings import (
39-
AuthSettings,
32+
from mcp.server.auth.provider import (
33+
AccessToken,
34+
OAuthAuthorizationServerProvider,
35+
TokenValidator,
4036
)
37+
from mcp.server.auth.settings import AuthSettings
4138
from mcp.server.fastmcp.exceptions import ResourceError
4239
from mcp.server.fastmcp.prompts import Prompt, PromptManager
4340
from mcp.server.fastmcp.resources import FunctionResource, Resource, ResourceManager
@@ -53,6 +50,7 @@
5350
from mcp.server.stdio import stdio_server
5451
from mcp.server.streamable_http import EventStore
5552
from mcp.server.streamable_http_manager import StreamableHTTPSessionManager
53+
from mcp.shared.auth import ProtectedResourceMetadata
5654
from mcp.shared.context import LifespanContextT, RequestContext
5755
from mcp.types import (
5856
AnyFunction,
@@ -143,8 +141,9 @@ def __init__(
143141
name: str | None = None,
144142
instructions: str | None = None,
145143
auth_server_details: dict[str, Any] | None = None,
146-
auth_server_provider: OAuthAuthorizationServerProvider[Any, Any, Any]
147-
| None = None,
144+
auth_server_provider: (
145+
OAuthAuthorizationServerProvider[Any, Any, Any] | None
146+
) = None,
148147
protected_resource_metadata: dict[str, Any] | None = None,
149148
event_store: EventStore | None = None,
150149
token_validator: TokenValidator[AccessToken] | None = None,
@@ -154,9 +153,11 @@ def __init__(
154153
self._auth_server_details = auth_server_details
155154
self._protected_resource_metadata = None
156155
if protected_resource_metadata:
157-
self._protected_resource_metadata = ProtectedResourceMetadata(**protected_resource_metadata)
156+
self._protected_resource_metadata = ProtectedResourceMetadata(
157+
**protected_resource_metadata
158+
)
158159
self._token_validator = token_validator
159-
160+
160161
self._mcp_server = MCPServer(
161162
name=name or "FastMCP",
162163
instructions=instructions,
@@ -176,7 +177,9 @@ def __init__(
176177
warn_on_duplicate_prompts=self.settings.warn_on_duplicate_prompts
177178
)
178179
# don't do this check if protected_resource_metadata is not None
179-
if (self.settings.auth is not None) != (auth_server_provider is not None) and self._protected_resource_metadata is None:
180+
if (self.settings.auth is not None) != (
181+
auth_server_provider is not None
182+
) and self._protected_resource_metadata is None:
180183
# TODO: after we support separate authorization servers (see
181184
# https://github.com/modelcontextprotocol/modelcontextprotocol/pull/284)
182185
# we should validate that if auth is enabled, we have either an
@@ -228,8 +231,13 @@ def session_manager(self) -> StreamableHTTPSessionManager:
228231
async def _serve_protected_resource_metadata(self, request: Request) -> Response:
229232
"""Serve the OAuth protected resource metadata."""
230233
if not self._protected_resource_metadata:
231-
raise HTTPException(status_code=404, detail="Protected resource metadata not configured")
232-
return Response(self._protected_resource_metadata.model_dump_json(), media_type="application/json")
234+
raise HTTPException(
235+
status_code=404, detail="Protected resource metadata not configured"
236+
)
237+
return Response(
238+
self._protected_resource_metadata.model_dump_json(),
239+
media_type="application/json",
240+
)
233241

234242
def run(
235243
self,
@@ -706,11 +714,10 @@ async def handle_sse(scope: Scope, receive: Receive, send: Send):
706714

707715
# Create routes
708716
routes: list[Route | Mount] = []
709-
717+
710718
middleware: list[Middleware] = []
711719
required_scopes = []
712720

713-
714721
# Add auth endpoints if auth provider is configured
715722
if self._auth_server_provider:
716723
assert self.settings.auth
@@ -808,21 +815,20 @@ async def handle_streamable_http(
808815
routes: list[Route | Mount] = []
809816
middleware: list[Middleware] = []
810817
required_scopes = []
811-
print("Protected resource metadata: ", self._protected_resource_metadata)
812818
if self._protected_resource_metadata and self._token_validator:
813-
print("Adding protected resource metadata route")
814-
# only add the well-known route if the protected resource metadata is configured
819+
# only add the well-known route if the protected resource metadata is
820+
# configured
815821
routes.append(
816822
Route(
817823
"/.well-known/oauth-protected-resource",
818824
self._serve_protected_resource_metadata,
819-
methods=["GET"]
825+
methods=["GET"],
820826
)
821827
)
822-
# by default assuming that this would be a JWT Bearer Token;
823-
# Make this also optional somehow; may be as part of the protected resource metadata,
824-
# take a class for validting the token
825-
middleware= [
828+
# by default assuming that this would be a JWT Bearer Token;
829+
# Make this also optional somehow; may be as part of the protected resource
830+
# metadata, take a class for validating the token
831+
middleware = [
826832
Middleware(
827833
AuthenticationMiddleware,
828834
backend=JWTBearerTokenAuthBackend(

src/mcp/shared/auth.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
from pydantic import AnyHttpUrl, BaseModel, Field
44

5+
56
class ProtectedResourceMetadata(BaseModel):
67
# create a pydantic model with required params as resource, authorization_servers
78
resource: str

0 commit comments

Comments
 (0)