@@ -45,18 +45,19 @@ async def set_client_info(self, client_info: OAuthClientInformationFull) -> None
45
45
class CallbackHandler (BaseHTTPRequestHandler ):
46
46
"""Simple HTTP handler to capture OAuth callback."""
47
47
48
- authorization_code = None
49
- state = None
50
- error = None
48
+ def __init__ (self , request , client_address , server , callback_data ):
49
+ """Initialize with callback data storage."""
50
+ self .callback_data = callback_data
51
+ super ().__init__ (request , client_address , server )
51
52
52
53
def do_GET (self ):
53
54
"""Handle GET request from OAuth redirect."""
54
55
parsed = urlparse (self .path )
55
56
query_params = parse_qs (parsed .query )
56
57
57
58
if "code" in query_params :
58
- CallbackHandler . authorization_code = query_params ["code" ][0 ]
59
- CallbackHandler . state = query_params .get ("state" , [None ])[0 ]
59
+ self . callback_data [ " authorization_code" ] = query_params ["code" ][0 ]
60
+ self . callback_data [ " state" ] = query_params .get ("state" , [None ])[0 ]
60
61
self .send_response (200 )
61
62
self .send_header ("Content-type" , "text/html" )
62
63
self .end_headers ()
@@ -70,7 +71,7 @@ def do_GET(self):
70
71
</html>
71
72
""" )
72
73
elif "error" in query_params :
73
- CallbackHandler . error = query_params ["error" ][0 ]
74
+ self . callback_data [ " error" ] = query_params ["error" ][0 ]
74
75
self .send_response (400 )
75
76
self .send_header ("Content-type" , "text/html" )
76
77
self .end_headers ()
@@ -101,10 +102,26 @@ def __init__(self, port=3000):
101
102
self .port = port
102
103
self .server = None
103
104
self .thread = None
105
+ self .callback_data = {
106
+ "authorization_code" : None ,
107
+ "state" : None ,
108
+ "error" : None
109
+ }
110
+
111
+ def _create_handler_with_data (self ):
112
+ """Create a handler class with access to callback data."""
113
+ callback_data = self .callback_data
114
+
115
+ class DataCallbackHandler (CallbackHandler ):
116
+ def __init__ (self , request , client_address , server ):
117
+ super ().__init__ (request , client_address , server , callback_data )
118
+
119
+ return DataCallbackHandler
104
120
105
121
def start (self ):
106
122
"""Start the callback server in a background thread."""
107
- self .server = HTTPServer (("localhost" , self .port ), CallbackHandler )
123
+ handler_class = self ._create_handler_with_data ()
124
+ self .server = HTTPServer (("localhost" , self .port ), handler_class )
108
125
self .thread = threading .Thread (target = self .server .serve_forever , daemon = True )
109
126
self .thread .start ()
110
127
print (f"🖥️ Started callback server on http://localhost:{ self .port } " )
@@ -121,12 +138,16 @@ def wait_for_callback(self, timeout=300):
121
138
"""Wait for OAuth callback with timeout."""
122
139
start_time = time .time ()
123
140
while time .time () - start_time < timeout :
124
- if CallbackHandler . authorization_code :
125
- return CallbackHandler . authorization_code
126
- elif CallbackHandler . error :
127
- raise Exception (f"OAuth error: { CallbackHandler . error } " )
141
+ if self . callback_data [ " authorization_code" ] :
142
+ return self . callback_data [ " authorization_code" ]
143
+ elif self . callback_data [ " error" ] :
144
+ raise Exception (f"OAuth error: { self . callback_data [ ' error' ] } " )
128
145
time .sleep (0.1 )
129
146
raise Exception ("Timeout waiting for OAuth callback" )
147
+
148
+ def get_state (self ):
149
+ """Get the received state parameter."""
150
+ return self .callback_data ["state" ]
130
151
131
152
132
153
class SimpleAuthClient :
@@ -153,7 +174,7 @@ async def callback_handler() -> tuple[str, str | None]:
153
174
print ("⏳ Waiting for authorization callback..." )
154
175
try :
155
176
auth_code = callback_server .wait_for_callback (timeout = 300 )
156
- return auth_code , CallbackHandler . state
177
+ return auth_code , callback_server . get_state ()
157
178
finally :
158
179
callback_server .stop ()
159
180
0 commit comments