Skip to content

Commit c7a361e

Browse files
committed
small fixes
1 parent 36be8db commit c7a361e

File tree

3 files changed

+121
-150
lines changed

3 files changed

+121
-150
lines changed

examples/clients/simple-auth-client/mcp_simple_auth_client/main.py

Lines changed: 33 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -45,18 +45,19 @@ async def set_client_info(self, client_info: OAuthClientInformationFull) -> None
4545
class CallbackHandler(BaseHTTPRequestHandler):
4646
"""Simple HTTP handler to capture OAuth callback."""
4747

48-
authorization_code = None
49-
state = None
50-
error = None
48+
def __init__(self, request, client_address, server, callback_data):
49+
"""Initialize with callback data storage."""
50+
self.callback_data = callback_data
51+
super().__init__(request, client_address, server)
5152

5253
def do_GET(self):
5354
"""Handle GET request from OAuth redirect."""
5455
parsed = urlparse(self.path)
5556
query_params = parse_qs(parsed.query)
5657

5758
if "code" in query_params:
58-
CallbackHandler.authorization_code = query_params["code"][0]
59-
CallbackHandler.state = query_params.get("state", [None])[0]
59+
self.callback_data["authorization_code"] = query_params["code"][0]
60+
self.callback_data["state"] = query_params.get("state", [None])[0]
6061
self.send_response(200)
6162
self.send_header("Content-type", "text/html")
6263
self.end_headers()
@@ -70,7 +71,7 @@ def do_GET(self):
7071
</html>
7172
""")
7273
elif "error" in query_params:
73-
CallbackHandler.error = query_params["error"][0]
74+
self.callback_data["error"] = query_params["error"][0]
7475
self.send_response(400)
7576
self.send_header("Content-type", "text/html")
7677
self.end_headers()
@@ -101,10 +102,26 @@ def __init__(self, port=3000):
101102
self.port = port
102103
self.server = None
103104
self.thread = None
105+
self.callback_data = {
106+
"authorization_code": None,
107+
"state": None,
108+
"error": None
109+
}
110+
111+
def _create_handler_with_data(self):
112+
"""Create a handler class with access to callback data."""
113+
callback_data = self.callback_data
114+
115+
class DataCallbackHandler(CallbackHandler):
116+
def __init__(self, request, client_address, server):
117+
super().__init__(request, client_address, server, callback_data)
118+
119+
return DataCallbackHandler
104120

105121
def start(self):
106122
"""Start the callback server in a background thread."""
107-
self.server = HTTPServer(("localhost", self.port), CallbackHandler)
123+
handler_class = self._create_handler_with_data()
124+
self.server = HTTPServer(("localhost", self.port), handler_class)
108125
self.thread = threading.Thread(target=self.server.serve_forever, daemon=True)
109126
self.thread.start()
110127
print(f"🖥️ Started callback server on http://localhost:{self.port}")
@@ -121,12 +138,16 @@ def wait_for_callback(self, timeout=300):
121138
"""Wait for OAuth callback with timeout."""
122139
start_time = time.time()
123140
while time.time() - start_time < timeout:
124-
if CallbackHandler.authorization_code:
125-
return CallbackHandler.authorization_code
126-
elif CallbackHandler.error:
127-
raise Exception(f"OAuth error: {CallbackHandler.error}")
141+
if self.callback_data["authorization_code"]:
142+
return self.callback_data["authorization_code"]
143+
elif self.callback_data["error"]:
144+
raise Exception(f"OAuth error: {self.callback_data['error']}")
128145
time.sleep(0.1)
129146
raise Exception("Timeout waiting for OAuth callback")
147+
148+
def get_state(self):
149+
"""Get the received state parameter."""
150+
return self.callback_data["state"]
130151

131152

132153
class SimpleAuthClient:
@@ -153,7 +174,7 @@ async def callback_handler() -> tuple[str, str | None]:
153174
print("⏳ Waiting for authorization callback...")
154175
try:
155176
auth_code = callback_server.wait_for_callback(timeout=300)
156-
return auth_code, CallbackHandler.state
177+
return auth_code, callback_server.get_state()
157178
finally:
158179
callback_server.stop()
159180

0 commit comments

Comments
 (0)