1
1
from __future__ import annotations
2
2
3
3
import asyncio
4
- import datetime
5
4
import json
6
5
import webbrowser
7
6
from pathlib import Path
20
19
OAuthMetadata as _MCPServerOAuthMetadata ,
21
20
)
22
21
from mcp .shared .auth import (
23
- OAuthToken as _MCPOAuthToken ,
22
+ OAuthToken as OAuthToken ,
24
23
)
25
- from pydantic import AnyHttpUrl , ValidationError , model_validator
26
- from typing_extensions import Self
24
+ from pydantic import AnyHttpUrl , ValidationError
27
25
28
26
from fastmcp .client .oauth_callback import (
29
27
create_oauth_callback_server ,
37
35
logger = get_logger (__name__ )
38
36
39
37
40
- class OAuthToken (_MCPOAuthToken ):
41
- """
42
- OAuth token that stores expiration as a datetime object
43
- """
44
-
45
- expires_at : datetime .datetime | None = None
46
-
47
- @model_validator (mode = "after" )
48
- def set_expires_at (self ) -> Self :
49
- if self .expires_in is not None and self .expires_at is None :
50
- now = datetime .datetime .now (datetime .timezone .utc )
51
- self .expires_at = now + datetime .timedelta (seconds = self .expires_in )
52
- return self
38
+ def default_cache_dir () -> Path :
39
+ return fastmcp_global_settings .home / "oauth-mcp-client-cache"
53
40
54
41
55
42
# Flexible OAuth models for real-world compatibility
@@ -140,9 +127,7 @@ class FileTokenStorage(TokenStorage):
140
127
def __init__ (self , server_url : str , cache_dir : Path | None = None ):
141
128
"""Initialize storage for a specific server URL."""
142
129
self .server_url = server_url
143
- self .cache_dir = (
144
- cache_dir or fastmcp_global_settings .home / "oauth-mcp-client-cache"
145
- )
130
+ self .cache_dir = cache_dir or default_cache_dir ()
146
131
self .cache_dir .mkdir (exist_ok = True , parents = True )
147
132
148
133
@staticmethod
@@ -172,21 +157,19 @@ async def get_tokens(self) -> OAuthToken | None:
172
157
173
158
try :
174
159
tokens = OAuthToken .model_validate_json (path .read_text ())
175
- now = datetime .datetime .now (datetime .timezone .utc )
176
- if tokens .expires_at is not None and tokens .expires_at <= now :
177
- logger .debug (f"Token expired for { self .get_base_url (self .server_url )} " )
178
- return None
160
+ # now = datetime.datetime.now(datetime.timezone.utc)
161
+ # if tokens.expires_at is not None and tokens.expires_at <= now:
162
+ # logger.debug(f"Token expired for {self.get_base_url(self.server_url)}")
163
+ # return None
179
164
return tokens
180
165
except (FileNotFoundError , json .JSONDecodeError , ValidationError ) as e :
181
166
logger .debug (
182
167
f"Could not load tokens for { self .get_base_url (self .server_url )} : { e } "
183
168
)
184
169
return None
185
170
186
- async def set_tokens (self , tokens : _MCPOAuthToken ) -> None :
171
+ async def set_tokens (self , tokens : OAuthToken ) -> None :
187
172
"""Save tokens to file storage."""
188
- # Convert to custom model with expiration datetime
189
- tokens = OAuthToken .model_validate (tokens .model_dump ())
190
173
path = self ._get_file_path ("tokens" )
191
174
path .write_text (tokens .model_dump_json (indent = 2 ))
192
175
logger .debug (f"Saved tokens for { self .get_base_url (self .server_url )} " )
@@ -195,7 +178,24 @@ async def get_client_info(self) -> OAuthClientInformationFull | None:
195
178
"""Load client information from file storage."""
196
179
path = self ._get_file_path ("client_info" )
197
180
try :
198
- return OAuthClientInformationFull .model_validate_json (path .read_text ())
181
+ client_info = OAuthClientInformationFull .model_validate_json (
182
+ path .read_text ()
183
+ )
184
+ # Check if we have corresponding valid tokens
185
+ # If no tokens exist, the OAuth flow was incomplete and we should
186
+ # force a fresh client registration
187
+ tokens = await self .get_tokens ()
188
+ if tokens is None :
189
+ logger .debug (
190
+ f"No tokens found for client info at { self .get_base_url (self .server_url )} . "
191
+ "OAuth flow may have been incomplete. Clearing client info to force fresh registration."
192
+ )
193
+ # Clear the incomplete client info
194
+ client_info_path = self ._get_file_path ("client_info" )
195
+ client_info_path .unlink (missing_ok = True )
196
+ return None
197
+
198
+ return client_info
199
199
except (FileNotFoundError , json .JSONDecodeError , ValidationError ) as e :
200
200
logger .debug (
201
201
f"Could not load client info for { self .get_base_url (self .server_url )} : { e } "
@@ -208,7 +208,7 @@ async def set_client_info(self, client_info: OAuthClientInformationFull) -> None
208
208
path .write_text (client_info .model_dump_json (indent = 2 ))
209
209
logger .debug (f"Saved client info for { self .get_base_url (self .server_url )} " )
210
210
211
- def clear_cache (self ) -> None :
211
+ def clear (self ) -> None :
212
212
"""Clear all cached data for this server."""
213
213
file_types : list [Literal ["client_info" , "tokens" ]] = ["client_info" , "tokens" ]
214
214
for file_type in file_types :
@@ -219,7 +219,7 @@ def clear_cache(self) -> None:
219
219
@classmethod
220
220
def clear_all (cls , cache_dir : Path | None = None ) -> None :
221
221
"""Clear all cached data for all servers."""
222
- cache_dir = cache_dir or fastmcp_global_settings . home / "oauth-mcp-client-cache"
222
+ cache_dir = cache_dir or default_cache_dir ()
223
223
if not cache_dir .exists ():
224
224
return
225
225
0 commit comments