Skip to content

Commit f93ab34

Browse files
committed
fixes
1 parent e701d0e commit f93ab34

File tree

2 files changed

+97
-29
lines changed
  • examples/clients/simple-auth-client/mcp_simple_auth_client
  • src/mcp/client

2 files changed

+97
-29
lines changed

examples/clients/simple-auth-client/mcp_simple_auth_client/main.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -162,8 +162,7 @@ async def callback_handler() -> tuple[str, str | None]:
162162
"redirect_uris": ["http://localhost:3000/callback"],
163163
"grant_types": ["authorization_code", "refresh_token"],
164164
"response_types": ["code"],
165-
"token_endpoint_auth_method": "client_secret_post", # Use client secret
166-
"scope": "read", # Default scope, will be updated
165+
"token_endpoint_auth_method": "client_secret_post",
167166
}
168167

169168
async def _default_redirect_handler(authorization_url: str) -> None:

src/mcp/client/auth.py

Lines changed: 96 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,21 @@ def _generate_code_challenge(self, code_verifier: str) -> str:
110110
digest = hashlib.sha256(code_verifier.encode()).digest()
111111
return base64.urlsafe_b64encode(digest).decode().rstrip("=")
112112

113+
def _get_authorization_base_url(self, server_url: str) -> str:
114+
"""
115+
Determine the authorization base URL by discarding any path component.
116+
117+
Per MCP spec Section 2.3.2: "The authorization base URL MUST be determined
118+
from the MCP server URL by discarding any existing path component."
119+
120+
Example: https://api.example.com/v1/mcp -> https://api.example.com
121+
"""
122+
from urllib.parse import urlparse, urlunparse
123+
124+
parsed = urlparse(server_url)
125+
# Discard path component by setting it to empty
126+
return urlunparse((parsed.scheme, parsed.netloc, "", "", "", ""))
127+
113128
async def _discover_oauth_metadata(self, server_url: str) -> OAuthMetadata | None:
114129
"""
115130
Discovers OAuth metadata from the server's well-known endpoint.
@@ -120,7 +135,9 @@ async def _discover_oauth_metadata(self, server_url: str) -> OAuthMetadata | Non
120135
Returns:
121136
OAuthMetadata if found, None otherwise
122137
"""
123-
url = urljoin(server_url, "/.well-known/oauth-authorization-server")
138+
# Get authorization base URL per MCP spec Section 2.3.2
139+
auth_base_url = self._get_authorization_base_url(server_url)
140+
url = urljoin(auth_base_url, "/.well-known/oauth-authorization-server")
124141
headers = {"MCP-Protocol-Version": LATEST_PROTOCOL_VERSION}
125142

126143
async with httpx.AsyncClient() as client:
@@ -171,24 +188,15 @@ async def _register_oauth_client(
171188
if metadata and metadata.registration_endpoint:
172189
registration_url = str(metadata.registration_endpoint)
173190
else:
174-
registration_url = urljoin(server_url, "/register")
191+
# Use authorization base URL for fallback registration endpoint
192+
auth_base_url = self._get_authorization_base_url(server_url)
193+
registration_url = urljoin(auth_base_url, "/register")
175194

176-
# Prepare registration data and adjust scope based on server metadata
195+
# Prepare registration data
177196
registration_data = client_metadata.model_dump(
178197
by_alias=True, mode="json", exclude_none=True
179198
)
180199

181-
# If the server has supported scopes, use them instead of the requested scope
182-
if metadata and metadata.scopes_supported:
183-
# Use the first supported scope or "user" if available
184-
if "user" in metadata.scopes_supported:
185-
registration_data["scope"] = "user"
186-
else:
187-
registration_data["scope"] = metadata.scopes_supported[0]
188-
logger.debug(
189-
f"Adjusted scope to server-supported: {registration_data['scope']}"
190-
)
191-
192200
async with httpx.AsyncClient() as client:
193201
try:
194202
response = await client.post(
@@ -252,6 +260,55 @@ def _has_valid_token(self) -> bool:
252260

253261
return True
254262

263+
async def _validate_token_scopes(self, token_response: OAuthToken) -> None:
264+
"""
265+
Validate that returned scopes are a subset of requested scopes.
266+
267+
Per OAuth 2.1 Section 3.3, the authorization server may issue a narrower
268+
set of scopes than requested, but must not grant additional scopes.
269+
"""
270+
if not token_response.scope:
271+
# If no scope is returned, validation passes (server didn't grant anything extra)
272+
return
273+
274+
# Get the originally requested scopes
275+
requested_scopes: set[str] = set()
276+
277+
# Check for explicitly requested scopes from client metadata
278+
if self.client_metadata.scope:
279+
requested_scopes.update(self.client_metadata.scope.split())
280+
281+
# If we have registered client info with specific scopes, use those
282+
# (This handles cases where scopes were negotiated during registration)
283+
if (
284+
self._client_info
285+
and hasattr(self._client_info, "scope")
286+
and self._client_info.scope
287+
):
288+
# Only override if the client metadata didn't have explicit scopes
289+
# This represents what was actually registered/negotiated with the server
290+
if not requested_scopes:
291+
requested_scopes.update(self._client_info.scope.split())
292+
293+
# Parse returned scopes
294+
returned_scopes: set[str] = set(token_response.scope.split())
295+
296+
# Validate that returned scopes are a subset of requested scopes
297+
# Only enforce strict validation if we actually have requested scopes
298+
if requested_scopes:
299+
unauthorized_scopes: set[str] = returned_scopes - requested_scopes
300+
if unauthorized_scopes:
301+
raise Exception(
302+
f"Server granted unauthorized scopes: {unauthorized_scopes}. "
303+
f"Requested: {requested_scopes}, Returned: {returned_scopes}"
304+
)
305+
else:
306+
# If no scopes were originally requested (fell back to server defaults),
307+
# accept whatever the server returned
308+
logger.debug(
309+
f"No specific scopes were requested, accepting server-granted scopes: {returned_scopes}"
310+
)
311+
255312
async def initialize(self) -> None:
256313
"""Initialize the auth handler by loading stored tokens and client info."""
257314
self._current_tokens = await self.storage.get_tokens()
@@ -307,7 +364,9 @@ async def _perform_oauth_flow(self) -> None:
307364
if self._metadata and self._metadata.authorization_endpoint:
308365
auth_url_base = str(self._metadata.authorization_endpoint)
309366
else:
310-
auth_url_base = urljoin(self.server_url, "/authorize")
367+
# Use authorization base URL for fallback authorization endpoint
368+
auth_base_url = self._get_authorization_base_url(self.server_url)
369+
auth_url_base = urljoin(auth_base_url, "/authorize")
311370

312371
# Build authorization URL
313372
auth_params = {
@@ -319,16 +378,16 @@ async def _perform_oauth_flow(self) -> None:
319378
"code_challenge_method": "S256",
320379
}
321380

322-
if hasattr(client_info, "scope") and client_info.scope:
323-
auth_params["scope"] = client_info.scope
324-
elif self._metadata and self._metadata.scopes_supported:
325-
# Use "user" if available, otherwise the first supported scope
326-
if "user" in self._metadata.scopes_supported:
327-
auth_params["scope"] = "user"
328-
else:
329-
auth_params["scope"] = self._metadata.scopes_supported[0]
330-
elif self.client_metadata.scope:
381+
# Set scope parameter following OAuth 2.1 principles:
382+
# 1. Use client's explicit request first (what developer wants)
383+
# 2. Use registered client scope as fallback (what was negotiated)
384+
# 3. No scope = let server decide (omit scope parameter)
385+
if self.client_metadata.scope:
331386
auth_params["scope"] = self.client_metadata.scope
387+
elif hasattr(client_info, "scope") and client_info.scope:
388+
auth_params["scope"] = client_info.scope
389+
# If no scope specified anywhere, don't include scope parameter
390+
# This lets the server grant default scopes per OAuth 2.1
332391

333392
auth_url = f"{auth_url_base}?{urlencode(auth_params)}"
334393

@@ -339,7 +398,7 @@ async def _perform_oauth_flow(self) -> None:
339398

340399
# Validate state parameter
341400
if returned_state != auth_params["state"]:
342-
raise Exception("State parameter mismatch - possible CSRF attack")
401+
raise Exception("State parameter mismatch")
343402

344403
if not auth_code:
345404
raise Exception("No authorization code received")
@@ -355,7 +414,9 @@ async def _exchange_code_for_token(
355414
if self._metadata and self._metadata.token_endpoint:
356415
token_url = str(self._metadata.token_endpoint)
357416
else:
358-
token_url = urljoin(self.server_url, "/token")
417+
# Use authorization base URL for fallback token endpoint
418+
auth_base_url = self._get_authorization_base_url(self.server_url)
419+
token_url = urljoin(auth_base_url, "/token")
359420

360421
token_data = {
361422
"grant_type": "authorization_code",
@@ -384,6 +445,9 @@ async def _exchange_code_for_token(
384445
# Parse and store tokens
385446
token_response = OAuthToken.model_validate(response.json())
386447

448+
# Validate returned scopes against requested scopes (OAuth 2.1 Section 3.3)
449+
await self._validate_token_scopes(token_response)
450+
387451
# Calculate expiry time if available
388452
if token_response.expires_in:
389453
self._token_expiry_time = time.time() + token_response.expires_in
@@ -406,7 +470,9 @@ async def _refresh_access_token(self) -> bool:
406470
if self._metadata and self._metadata.token_endpoint:
407471
token_url = str(self._metadata.token_endpoint)
408472
else:
409-
token_url = urljoin(self.server_url, "/token")
473+
# Use authorization base URL for fallback token endpoint
474+
auth_base_url = self._get_authorization_base_url(self.server_url)
475+
token_url = urljoin(auth_base_url, "/token")
410476

411477
refresh_data = {
412478
"grant_type": "refresh_token",
@@ -433,6 +499,9 @@ async def _refresh_access_token(self) -> bool:
433499
# Parse and store new tokens
434500
token_response = OAuthToken.model_validate(response.json())
435501

502+
# Validate returned scopes against requested scopes (OAuth 2.1 Section 3.3)
503+
await self._validate_token_scopes(token_response)
504+
436505
# Calculate expiry time if available
437506
if token_response.expires_in:
438507
self._token_expiry_time = time.time() + token_response.expires_in

0 commit comments

Comments
 (0)