@@ -142,18 +142,31 @@ async def index(request: Request):
142
142
return templates .TemplateResponse ("index.html" , {"request" : request })
143
143
144
144
145
- async def get_agent_variant (feature_manager , ai_client : AIProjectClient , thread_id : str ):
146
- if feature_manager :
147
- # Fetch the variant for the feature flag "my-agent" using thread_id as targeting Id
148
- agent_variant = await feature_manager .get_variant ("my-agent" , thread_id )
145
+ async def get_agent_with_feature_flag (
146
+ feature_manager : FeatureManager ,
147
+ ai_client : AIProjectClient ,
148
+ thread_id : str ,
149
+ default_agent : Agent = Depends (get_agent ),
150
+ ) -> Optional [Agent ]:
151
+ if thread_id and feature_manager :
152
+ # Fetch the variant for the feature flag "my-agent" using thread_id as targeting context
153
+ try :
154
+ agent_variant = await feature_manager .get_variant ("my-agent" , thread_id )
155
+ except Exception as e :
156
+ logger .error (f"Error fetching feature flag for thread_id={ thread_id } : { e } " )
157
+ agent_variant = None
158
+
159
+ # Retrieve the agent variant if available
149
160
if agent_variant and agent_variant .configuration :
150
161
try :
151
- await ai_client .agents .get_agent (agent_variant .configuration )
162
+ assigned_agent = await ai_client .agents .get_agent (agent_variant .configuration )
152
163
logger .info (f"Using variant={ agent_variant .name } with agent Id={ agent_variant .configuration } for thread_id={ thread_id } " )
153
- return agent_variant
164
+ return assigned_agent
154
165
except Exception as e :
155
- logger .error (f"Error retrieving agent variant with Id={ agent_variant .configuration } from AI project. { e } " )
156
- return None
166
+ logger .error (f"Error retrieving agent variant with Id={ agent_variant .configuration } from AI project. Fallback to default agent: { e } " )
167
+
168
+ # No agent variant found, fallback to default agent
169
+ return default_agent
157
170
158
171
async def get_result (thread_id : str , agent_id : str , ai_client : AIProjectClient ) -> AsyncGenerator [str , None ]:
159
172
logger .info (f"get_result invoked for thread_id={ thread_id } and agent_id={ agent_id } " )
@@ -180,15 +193,25 @@ async def get_result(thread_id: str, agent_id: str, ai_client : AIProjectClient)
180
193
@router .get ("/chat/history" )
181
194
async def history (
182
195
request : Request ,
183
- ai_client : AIProjectClient = Depends (get_ai_client )
196
+ default_agent : Agent = Depends (get_agent ),
197
+ ai_client : AIProjectClient = Depends (get_ai_client ),
198
+ feature_manager : FeatureManager = Depends (get_feature_manager ),
199
+ app_config : AzureAppConfigurationProvider = Depends (get_app_config ),
184
200
):
201
+ # Refresh config if configured
202
+ if app_config :
203
+ await app_config .refresh ()
204
+
185
205
# Retrieve the thread ID from the cookies (if available).
186
206
thread_id = request .cookies .get ('thread_id' )
187
207
agent_id = request .cookies .get ('agent_id' )
188
208
189
- # Attempt to get an existing thread. If not found, create a new one.
209
+ # Attempt to get agent from feature flag and fallback to default agent if not found.
210
+ agent = await get_agent_with_feature_flag (feature_manager , ai_client , thread_id , default_agent )
211
+
212
+ # Attempt to get an existing thread. If not found or agent has changed, create a new one.
190
213
try :
191
- if thread_id :
214
+ if thread_id and agent_id == agent . id :
192
215
logger .info (f"Retrieving thread with ID { thread_id } " )
193
216
thread = await ai_client .agents .get_thread (thread_id )
194
217
else :
@@ -199,7 +222,6 @@ async def history(
199
222
raise HTTPException (status_code = 400 , detail = f"Error handling thread: { e } " )
200
223
201
224
thread_id = thread .id
202
- agent_id = agent .id
203
225
204
226
# Create a new message from the user's input.
205
227
try :
@@ -229,21 +251,24 @@ async def history(
229
251
async def chat (
230
252
request : Request ,
231
253
ai_client : AIProjectClient = Depends (get_ai_client ),
232
- agent : Agent = Depends (get_agent ),
254
+ default_agent : Agent = Depends (get_agent ),
233
255
feature_manager : FeatureManager = Depends (get_feature_manager ),
234
256
app_config : AzureAppConfigurationProvider = Depends (get_app_config ),
235
257
):
236
258
# Refresh config if configured
237
259
if app_config :
238
- app_config .refresh ()
260
+ await app_config .refresh ()
239
261
240
262
# Retrieve the thread ID from the cookies (if available).
241
263
thread_id = request .cookies .get ('thread_id' )
242
264
agent_id = request .cookies .get ('agent_id' )
243
265
266
+ # Attempt to get agent from feature flag and fallback to default agent if not found.
267
+ agent = await get_agent_with_feature_flag (feature_manager , ai_client , thread_id , default_agent )
268
+
244
269
# Attempt to get an existing thread. If not found, create a new one.
245
270
try :
246
- if thread_id :
271
+ if thread_id and agent_id == agent . id :
247
272
logger .info (f"Retrieving thread with ID { thread_id } " )
248
273
thread = await ai_client .agents .get_thread (thread_id )
249
274
else :
@@ -254,8 +279,7 @@ async def chat(
254
279
raise HTTPException (status_code = 400 , detail = f"Error handling thread: { e } " )
255
280
256
281
thread_id = thread .id
257
- agent_variant = await get_agent_variant (feature_manager , ai_client , thread_id )
258
- agent_id = agent_variant .configuration if agent_variant else agent .id
282
+ agent_id = agent .id
259
283
260
284
# Parse the JSON from the request.
261
285
try :
@@ -293,10 +317,6 @@ async def chat(
293
317
response .set_cookie ("thread_id" , thread_id )
294
318
response .set_cookie ("agent_id" , agent_id )
295
319
296
- # Set the agent variant name in the response headers if available.
297
- if agent_variant :
298
- response .headers ["agent-variant" ] = str (agent_variant .name )
299
-
300
320
return response
301
321
302
322
0 commit comments