@@ -110,6 +110,21 @@ def _generate_code_challenge(self, code_verifier: str) -> str:
110
110
digest = hashlib .sha256 (code_verifier .encode ()).digest ()
111
111
return base64 .urlsafe_b64encode (digest ).decode ().rstrip ("=" )
112
112
113
+ def _get_authorization_base_url (self , server_url : str ) -> str :
114
+ """
115
+ Determine the authorization base URL by discarding any path component.
116
+
117
+ Per MCP spec Section 2.3.2: "The authorization base URL MUST be determined
118
+ from the MCP server URL by discarding any existing path component."
119
+
120
+ Example: https://api.example.com/v1/mcp -> https://api.example.com
121
+ """
122
+ from urllib .parse import urlparse , urlunparse
123
+
124
+ parsed = urlparse (server_url )
125
+ # Discard path component by setting it to empty
126
+ return urlunparse ((parsed .scheme , parsed .netloc , "" , "" , "" , "" ))
127
+
113
128
async def _discover_oauth_metadata (self , server_url : str ) -> OAuthMetadata | None :
114
129
"""
115
130
Discovers OAuth metadata from the server's well-known endpoint.
@@ -120,7 +135,9 @@ async def _discover_oauth_metadata(self, server_url: str) -> OAuthMetadata | Non
120
135
Returns:
121
136
OAuthMetadata if found, None otherwise
122
137
"""
123
- url = urljoin (server_url , "/.well-known/oauth-authorization-server" )
138
+ # Get authorization base URL per MCP spec Section 2.3.2
139
+ auth_base_url = self ._get_authorization_base_url (server_url )
140
+ url = urljoin (auth_base_url , "/.well-known/oauth-authorization-server" )
124
141
headers = {"MCP-Protocol-Version" : LATEST_PROTOCOL_VERSION }
125
142
126
143
async with httpx .AsyncClient () as client :
@@ -171,24 +188,15 @@ async def _register_oauth_client(
171
188
if metadata and metadata .registration_endpoint :
172
189
registration_url = str (metadata .registration_endpoint )
173
190
else :
174
- registration_url = urljoin (server_url , "/register" )
191
+ # Use authorization base URL for fallback registration endpoint
192
+ auth_base_url = self ._get_authorization_base_url (server_url )
193
+ registration_url = urljoin (auth_base_url , "/register" )
175
194
176
- # Prepare registration data and adjust scope based on server metadata
195
+ # Prepare registration data
177
196
registration_data = client_metadata .model_dump (
178
197
by_alias = True , mode = "json" , exclude_none = True
179
198
)
180
199
181
- # If the server has supported scopes, use them instead of the requested scope
182
- if metadata and metadata .scopes_supported :
183
- # Use the first supported scope or "user" if available
184
- if "user" in metadata .scopes_supported :
185
- registration_data ["scope" ] = "user"
186
- else :
187
- registration_data ["scope" ] = metadata .scopes_supported [0 ]
188
- logger .debug (
189
- f"Adjusted scope to server-supported: { registration_data ['scope' ]} "
190
- )
191
-
192
200
async with httpx .AsyncClient () as client :
193
201
try :
194
202
response = await client .post (
@@ -252,6 +260,55 @@ def _has_valid_token(self) -> bool:
252
260
253
261
return True
254
262
263
+ async def _validate_token_scopes (self , token_response : OAuthToken ) -> None :
264
+ """
265
+ Validate that returned scopes are a subset of requested scopes.
266
+
267
+ Per OAuth 2.1 Section 3.3, the authorization server may issue a narrower
268
+ set of scopes than requested, but must not grant additional scopes.
269
+ """
270
+ if not token_response .scope :
271
+ # If no scope is returned, validation passes (server didn't grant anything extra)
272
+ return
273
+
274
+ # Get the originally requested scopes
275
+ requested_scopes : set [str ] = set ()
276
+
277
+ # Check for explicitly requested scopes from client metadata
278
+ if self .client_metadata .scope :
279
+ requested_scopes .update (self .client_metadata .scope .split ())
280
+
281
+ # If we have registered client info with specific scopes, use those
282
+ # (This handles cases where scopes were negotiated during registration)
283
+ if (
284
+ self ._client_info
285
+ and hasattr (self ._client_info , "scope" )
286
+ and self ._client_info .scope
287
+ ):
288
+ # Only override if the client metadata didn't have explicit scopes
289
+ # This represents what was actually registered/negotiated with the server
290
+ if not requested_scopes :
291
+ requested_scopes .update (self ._client_info .scope .split ())
292
+
293
+ # Parse returned scopes
294
+ returned_scopes : set [str ] = set (token_response .scope .split ())
295
+
296
+ # Validate that returned scopes are a subset of requested scopes
297
+ # Only enforce strict validation if we actually have requested scopes
298
+ if requested_scopes :
299
+ unauthorized_scopes : set [str ] = returned_scopes - requested_scopes
300
+ if unauthorized_scopes :
301
+ raise Exception (
302
+ f"Server granted unauthorized scopes: { unauthorized_scopes } . "
303
+ f"Requested: { requested_scopes } , Returned: { returned_scopes } "
304
+ )
305
+ else :
306
+ # If no scopes were originally requested (fell back to server defaults),
307
+ # accept whatever the server returned
308
+ logger .debug (
309
+ f"No specific scopes were requested, accepting server-granted scopes: { returned_scopes } "
310
+ )
311
+
255
312
async def initialize (self ) -> None :
256
313
"""Initialize the auth handler by loading stored tokens and client info."""
257
314
self ._current_tokens = await self .storage .get_tokens ()
@@ -307,7 +364,9 @@ async def _perform_oauth_flow(self) -> None:
307
364
if self ._metadata and self ._metadata .authorization_endpoint :
308
365
auth_url_base = str (self ._metadata .authorization_endpoint )
309
366
else :
310
- auth_url_base = urljoin (self .server_url , "/authorize" )
367
+ # Use authorization base URL for fallback authorization endpoint
368
+ auth_base_url = self ._get_authorization_base_url (self .server_url )
369
+ auth_url_base = urljoin (auth_base_url , "/authorize" )
311
370
312
371
# Build authorization URL
313
372
auth_params = {
@@ -319,16 +378,16 @@ async def _perform_oauth_flow(self) -> None:
319
378
"code_challenge_method" : "S256" ,
320
379
}
321
380
322
- if hasattr (client_info , "scope" ) and client_info .scope :
323
- auth_params ["scope" ] = client_info .scope
324
- elif self ._metadata and self ._metadata .scopes_supported :
325
- # Use "user" if available, otherwise the first supported scope
326
- if "user" in self ._metadata .scopes_supported :
327
- auth_params ["scope" ] = "user"
328
- else :
329
- auth_params ["scope" ] = self ._metadata .scopes_supported [0 ]
330
- elif self .client_metadata .scope :
381
+ # Set scope parameter following OAuth 2.1 principles:
382
+ # 1. Use client's explicit request first (what developer wants)
383
+ # 2. Use registered client scope as fallback (what was negotiated)
384
+ # 3. No scope = let server decide (omit scope parameter)
385
+ if self .client_metadata .scope :
331
386
auth_params ["scope" ] = self .client_metadata .scope
387
+ elif hasattr (client_info , "scope" ) and client_info .scope :
388
+ auth_params ["scope" ] = client_info .scope
389
+ # If no scope specified anywhere, don't include scope parameter
390
+ # This lets the server grant default scopes per OAuth 2.1
332
391
333
392
auth_url = f"{ auth_url_base } ?{ urlencode (auth_params )} "
334
393
@@ -339,7 +398,7 @@ async def _perform_oauth_flow(self) -> None:
339
398
340
399
# Validate state parameter
341
400
if returned_state != auth_params ["state" ]:
342
- raise Exception ("State parameter mismatch - possible CSRF attack " )
401
+ raise Exception ("State parameter mismatch" )
343
402
344
403
if not auth_code :
345
404
raise Exception ("No authorization code received" )
@@ -355,7 +414,9 @@ async def _exchange_code_for_token(
355
414
if self ._metadata and self ._metadata .token_endpoint :
356
415
token_url = str (self ._metadata .token_endpoint )
357
416
else :
358
- token_url = urljoin (self .server_url , "/token" )
417
+ # Use authorization base URL for fallback token endpoint
418
+ auth_base_url = self ._get_authorization_base_url (self .server_url )
419
+ token_url = urljoin (auth_base_url , "/token" )
359
420
360
421
token_data = {
361
422
"grant_type" : "authorization_code" ,
@@ -384,6 +445,9 @@ async def _exchange_code_for_token(
384
445
# Parse and store tokens
385
446
token_response = OAuthToken .model_validate (response .json ())
386
447
448
+ # Validate returned scopes against requested scopes (OAuth 2.1 Section 3.3)
449
+ await self ._validate_token_scopes (token_response )
450
+
387
451
# Calculate expiry time if available
388
452
if token_response .expires_in :
389
453
self ._token_expiry_time = time .time () + token_response .expires_in
@@ -406,7 +470,9 @@ async def _refresh_access_token(self) -> bool:
406
470
if self ._metadata and self ._metadata .token_endpoint :
407
471
token_url = str (self ._metadata .token_endpoint )
408
472
else :
409
- token_url = urljoin (self .server_url , "/token" )
473
+ # Use authorization base URL for fallback token endpoint
474
+ auth_base_url = self ._get_authorization_base_url (self .server_url )
475
+ token_url = urljoin (auth_base_url , "/token" )
410
476
411
477
refresh_data = {
412
478
"grant_type" : "refresh_token" ,
@@ -433,6 +499,9 @@ async def _refresh_access_token(self) -> bool:
433
499
# Parse and store new tokens
434
500
token_response = OAuthToken .model_validate (response .json ())
435
501
502
+ # Validate returned scopes against requested scopes (OAuth 2.1 Section 3.3)
503
+ await self ._validate_token_scopes (token_response )
504
+
436
505
# Calculate expiry time if available
437
506
if token_response .expires_in :
438
507
self ._token_expiry_time = time .time () + token_response .expires_in
0 commit comments