Skip to content

Commit 6768efb

Browse files
committed
Hoist oauth token expiration check into bearer auth middleware
1 parent 8e45e97 commit 6768efb

File tree

3 files changed

+16
-20
lines changed

3 files changed

+16
-20
lines changed

src/mcp/server/auth/middleware/bearer_auth.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,10 @@ async def authenticate(self, conn: HTTPConnection):
5353

5454
try:
5555
# Validate the token with the provider
56-
auth_info = await self.provider.verify_access_token(token)
56+
auth_info = await self.provider.load_access_token(token)
57+
58+
if not auth_info:
59+
raise InvalidTokenError("Invalid access token")
5760

5861
if auth_info.expires_at and auth_info.expires_at < int(time.time()):
5962
raise InvalidTokenError("Token has expired")

src/mcp/server/auth/provider.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -183,17 +183,15 @@ async def exchange_refresh_token(
183183
"""
184184
...
185185

186-
# TODO: consider methods to generate refresh tokens and access tokens
187-
188-
async def verify_access_token(self, token: str) -> AuthInfo:
186+
async def load_access_token(self, token: str) -> AuthInfo | None:
189187
"""
190188
Verifies an access token and returns information about it.
191189
192190
Args:
193191
token: The access token to verify.
194192
195193
Returns:
196-
Information about the verified token.
194+
Information about the verified token, or None if the token is invalid.
197195
"""
198196
...
199197

tests/server/fastmcp/auth/test_auth_integration.py

Lines changed: 10 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -189,19 +189,14 @@ async def exchange_refresh_token(
189189
refresh_token=new_refresh_token,
190190
)
191191

192-
async def verify_access_token(self, token: str) -> AuthInfo:
193-
# Check if token exists
194-
if token not in self.tokens:
195-
raise InvalidTokenError("Invalid access token")
196-
197-
# Get token info
198-
token_info = self.tokens[token]
192+
async def load_access_token(self, token: str) -> AuthInfo | None:
193+
token_info = self.tokens.get(token)
199194

200195
# Check if token is expired
201-
if token_info.expires_at < int(time.time()):
202-
raise InvalidTokenError("Access token has expired")
196+
# if token_info.expires_at < int(time.time()):
197+
# raise InvalidTokenError("Access token has expired")
203198

204-
return AuthInfo(
199+
return token_info and AuthInfo(
205200
token=token,
206201
client_id=token_info.client_id,
207202
scopes=token_info.scopes,
@@ -852,7 +847,8 @@ async def test_authorization_get(
852847
refresh_token = token_response["refresh_token"]
853848

854849
# Create a test client with the token
855-
auth_info = await mock_oauth_provider.verify_access_token(access_token)
850+
auth_info = await mock_oauth_provider.load_access_token(access_token)
851+
assert auth_info
856852
assert auth_info.client_id == client_info["client_id"]
857853
assert "read" in auth_info.scopes
858854
assert "write" in auth_info.scopes
@@ -888,10 +884,9 @@ async def test_authorization_get(
888884
assert response.status_code == 200
889885

890886
# Verify that the token was revoked
891-
with pytest.raises(InvalidTokenError):
892-
await mock_oauth_provider.verify_access_token(
893-
new_token_response["access_token"]
894-
)
887+
assert await mock_oauth_provider.load_access_token(
888+
new_token_response["access_token"]
889+
) is None
895890

896891

897892
class TestFastMCPWithAuth:

0 commit comments

Comments
 (0)