@@ -2,14 +2,15 @@
 #
 # SPDX-License-Identifier: Apache-2.0
 
-import asyncio
 import inspect
 from copy import deepcopy
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, Iterator, List, Optional
 
-from haystack import component, default_from_dict, default_to_dict, logging
+from haystack import component, default_from_dict, default_to_dict, logging, tracing
 from haystack.components.generators.chat.types import ChatGenerator
 from haystack.components.tools import ToolInvoker
+from haystack.core.pipeline.async_pipeline import AsyncPipeline
+from haystack.core.pipeline.pipeline import Pipeline
 from haystack.core.serialization import component_to_dict
 from haystack.dataclasses import ChatMessage
 from haystack.dataclasses.state import State, _schema_from_dict, _schema_to_dict, _validate_schema
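The import changes track the rework below: the manual asyncio executor fallback is dropped in favor of Pipeline._run_component / AsyncPipeline._run_component_async, and tracing comes in for the new agent span. To actually see that span and its content tags while developing, Haystack's built-in LoggingTracer can be enabled; a minimal sketch (content tracing is off by default, and turning it on logs real message contents):

import logging

from haystack import tracing
from haystack.tracing.logging_tracer import LoggingTracer

logging.basicConfig(level=logging.DEBUG)

# Opt in to content tags such as "haystack.agent.input"/"haystack.agent.output";
# without this, only the regular tags set on the span are recorded.
tracing.tracer.is_content_tracing_enabled = True
tracing.enable_tracing(LoggingTracer())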
@@ -187,6 +188,26 @@ def from_dict(cls, data: Dict[str, Any]) -> "Agent":
 
         return default_from_dict(cls, data)
 
+    def _prepare_generator_inputs(self, streaming_callback: Optional[StreamingCallbackT] = None) -> Dict[str, Any]:
+        """Prepare inputs for the chat generator."""
+        generator_inputs = {"tools": self.tools}
+        selected_callback = streaming_callback or self.streaming_callback
+        if selected_callback is not None:
+            generator_inputs["streaming_callback"] = selected_callback
+        return generator_inputs
+
+    def _create_agent_span(self) -> Iterator[tracing.Span]:
+        """Create a span for the agent run."""
+        return tracing.tracer.trace(
+            "haystack.agent.run",
+            tags={
+                "haystack.agent.max_steps": self.max_agent_steps,
+                "haystack.agent.tools": self.tools,
+                "haystack.agent.exit_conditions": self.exit_conditions,
+                "haystack.agent.state_schema": self.state_schema,
+            },
+        )
+
     def run(
         self,
         messages: List[ChatMessage],
@@ -205,48 +226,66 @@ def run(
         if not self._is_warmed_up and hasattr(self.chat_generator, "warm_up"):
             raise RuntimeError("The component Agent wasn't warmed up. Run 'warm_up()' before calling 'run()'.")
 
-        state = State(schema=self.state_schema, data=kwargs)
-
         if self.system_prompt is not None:
             messages = [ChatMessage.from_system(self.system_prompt)] + messages
+
+        input_data = deepcopy({"messages": messages, "streaming_callback": streaming_callback, **kwargs})
+
+        state = State(schema=self.state_schema, data=kwargs)
         state.set("messages", messages)
 
-        generator_inputs: Dict[str, Any] = {"tools": self.tools}
+        generator_inputs = self._prepare_generator_inputs(streaming_callback=streaming_callback)
+
+        component_visits = dict.fromkeys(["chat_generator", "tool_invoker"], 0)
+        with self._create_agent_span() as span:
+            span.set_content_tag("haystack.agent.input", input_data)
+            counter = 0
+            while counter < self.max_agent_steps:
+                # 1. Call the ChatGenerator
+                llm_messages = Pipeline._run_component(
+                    component_name="chat_generator",
+                    component={"instance": self.chat_generator},
+                    inputs={"messages": messages, **generator_inputs},
+                    component_visits=component_visits,
+                    parent_span=span,
+                )["replies"]
+                state.set("messages", llm_messages)
+
+                # 2. Check if any of the LLM responses contain a tool call
+                if not any(msg.tool_call for msg in llm_messages):
+                    counter += 1
+                    break
 
-        selected_callback = streaming_callback or self.streaming_callback
-        if selected_callback is not None:
-            generator_inputs["streaming_callback"] = selected_callback
+                # 3. Call the ToolInvoker
+                # We only send the messages from the LLM to the tool invoker
+                tool_invoker_result = Pipeline._run_component(
+                    component_name="tool_invoker",
+                    component={"instance": self._tool_invoker},
+                    inputs={"messages": llm_messages, "state": state},
+                    component_visits=component_visits,
+                    parent_span=span,
+                )
+                tool_messages = tool_invoker_result["tool_messages"]
+                state = tool_invoker_result["state"]
+                state.set("messages", tool_messages)
 
-        # Repeat until the exit condition is met
-        counter = 0
-        while counter < self.max_agent_steps:
-            # 1. Call the ChatGenerator
-            llm_messages = self.chat_generator.run(messages=messages, **generator_inputs)["replies"]
-            state.set("messages", llm_messages)
-
-            # 2. Check if any of the LLM responses contain a tool call
-            if not any(msg.tool_call for msg in llm_messages):
-                return {**state.data}
-
-            # 3. Call the ToolInvoker
-            # We only send the messages from the LLM to the tool invoker
-            tool_invoker_result = self._tool_invoker.run(messages=llm_messages, state=state)
-            tool_messages = tool_invoker_result["tool_messages"]
-            state = tool_invoker_result["state"]
-            state.set("messages", tool_messages)
-
-            # 4. Check if any LLM message's tool call name matches an exit condition
-            if self.exit_conditions != ["text"] and self._check_exit_conditions(llm_messages, tool_messages):
-                return {**state.data}
-
-            # 5. Fetch the combined messages and send them back to the LLM
-            messages = state.get("messages")
-            counter += 1
-
-        logger.warning(
-            "Agent exceeded maximum agent steps of {max_agent_steps}, stopping.", max_agent_steps=self.max_agent_steps
-        )
-        return {**state.data}
+                # 4. Check if any LLM message's tool call name matches an exit condition
+                if self.exit_conditions != ["text"] and self._check_exit_conditions(llm_messages, tool_messages):
+                    counter += 1
+                    break
+
+                # 5. Fetch the combined messages and send them back to the LLM
+                messages = state.get("messages")
+                counter += 1
+
+            if counter >= self.max_agent_steps:
+                logger.warning(
+                    "Agent reached maximum agent steps of {max_agent_steps}, stopping.",
+                    max_agent_steps=self.max_agent_steps,
+                )
+            span.set_content_tag("haystack.agent.output", state.data)
+            span.set_tag("haystack.agent.steps_taken", counter)
+            return state.data
 
     async def run_async(
         self,
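Taken together, the reworked run() delegates each step to Pipeline._run_component (which handles visit counting and per-component child spans) and exits the loop via break, so the output tag and steps_taken are always recorded before returning. A minimal usage sketch against this API; the get_weather tool is hypothetical, an OPENAI_API_KEY in the environment is assumed, and the import path assumes the module layout this diff targets:

from haystack.components.agents import Agent
from haystack.components.generators.chat import OpenAIChatGenerator
from haystack.dataclasses import ChatMessage
from haystack.tools import Tool

def get_weather(city: str) -> dict:
    """Hypothetical tool body; swap in a real lookup."""
    return {"weather": "sunny", "city": city}

weather_tool = Tool(
    name="weather",
    description="Get the current weather for a city.",
    parameters={"type": "object", "properties": {"city": {"type": "string"}}, "required": ["city"]},
    function=get_weather,
)

agent = Agent(
    chat_generator=OpenAIChatGenerator(),
    tools=[weather_tool],
    exit_conditions=["text"],  # stop as soon as the LLM replies without a tool call
)
agent.warm_up()  # required before run(), per the RuntimeError above

result = agent.run(messages=[ChatMessage.from_user("What's the weather in Berlin?")])
print(result["messages"][-1].text)  # run() returns state.data, which includes "messages"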
@@ -270,66 +309,70 @@ async def run_async(
         if not self._is_warmed_up and hasattr(self.chat_generator, "warm_up"):
             raise RuntimeError("The component Agent wasn't warmed up. Run 'warm_up()' before calling 'run_async()'.")
 
-        state = State(schema=self.state_schema, data=kwargs)
-
         if self.system_prompt is not None:
             messages = [ChatMessage.from_system(self.system_prompt)] + messages
-        state.set("messages", messages)
 
-        generator_inputs: Dict[str, Any] = {"tools": self.tools}
+        input_data = deepcopy({"messages": messages, "streaming_callback": streaming_callback, **kwargs})
 
-        selected_callback = streaming_callback or self.streaming_callback
-        if selected_callback is not None:
-            generator_inputs["streaming_callback"] = selected_callback
+        state = State(schema=self.state_schema, data=kwargs)
+        state.set("messages", messages)
 
-        # Repeat until the exit condition is met
-        counter = 0
-        while counter < self.max_agent_steps:
-            # 1. Call the ChatGenerator
-            # Check if the chat generator supports async execution
-            if getattr(self.chat_generator, "__haystack_supports_async__", False):
-                result = await self.chat_generator.run_async(messages=messages, **generator_inputs)  # type: ignore[attr-defined]
-                llm_messages = result["replies"]
-            else:
-                # Fall back to synchronous run if async is not available
-                loop = asyncio.get_running_loop()
-                result = await loop.run_in_executor(
-                    None, lambda: self.chat_generator.run(messages=messages, **generator_inputs)
+        generator_inputs = self._prepare_generator_inputs(streaming_callback=streaming_callback)
+
+        component_visits = dict.fromkeys(["chat_generator", "tool_invoker"], 0)
+        with self._create_agent_span() as span:
+            span.set_content_tag("haystack.agent.input", input_data)
+            counter = 0
+            while counter < self.max_agent_steps:
+                # 1. Call the ChatGenerator
+                result = await AsyncPipeline._run_component_async(
+                    component_name="chat_generator",
+                    component={"instance": self.chat_generator},
+                    component_inputs={"messages": messages, **generator_inputs},
+                    component_visits=component_visits,
+                    max_runs_per_component=self.max_agent_steps,
+                    parent_span=span,
                 )
                 llm_messages = result["replies"]
+                state.set("messages", llm_messages)
+
+                # 2. Check if any of the LLM responses contain a tool call
+                if not any(msg.tool_call for msg in llm_messages):
+                    counter += 1
+                    break
 
-            state.set("messages", llm_messages)
-
-            # 2. Check if any of the LLM responses contain a tool call
-            if not any(msg.tool_call for msg in llm_messages):
-                return {**state.data}
-
-            # 3. Call the ToolInvoker
-            # We only send the messages from the LLM to the tool invoker
-            # Check if the ToolInvoker supports async execution. Currently, it doesn't.
-            if getattr(self._tool_invoker, "__haystack_supports_async__", False):
-                tool_invoker_result = await self._tool_invoker.run_async(messages=llm_messages, state=state)  # type: ignore[attr-defined]
-            else:
-                loop = asyncio.get_running_loop()
-                tool_invoker_result = await loop.run_in_executor(
-                    None, lambda: self._tool_invoker.run(messages=llm_messages, state=state)
+                # 3. Call the ToolInvoker
+                # We only send the messages from the LLM to the tool invoker
+                # Check if the ToolInvoker supports async execution. Currently, it doesn't.
+                tool_invoker_result = await AsyncPipeline._run_component_async(
+                    component_name="tool_invoker",
+                    component={"instance": self._tool_invoker},
+                    component_inputs={"messages": llm_messages, "state": state},
+                    component_visits=component_visits,
+                    max_runs_per_component=self.max_agent_steps,
+                    parent_span=span,
                 )
-            tool_messages = tool_invoker_result["tool_messages"]
-            state = tool_invoker_result["state"]
-            state.set("messages", tool_messages)
+                tool_messages = tool_invoker_result["tool_messages"]
+                state = tool_invoker_result["state"]
+                state.set("messages", tool_messages)
 
-            # 4. Check if any LLM message's tool call name matches an exit condition
-            if self.exit_conditions != ["text"] and self._check_exit_conditions(llm_messages, tool_messages):
-                return {**state.data}
+                # 4. Check if any LLM message's tool call name matches an exit condition
+                if self.exit_conditions != ["text"] and self._check_exit_conditions(llm_messages, tool_messages):
+                    counter += 1
+                    break
 
-            # 5. Fetch the combined messages and send them back to the LLM
-            messages = state.get("messages")
-            counter += 1
+                # 5. Fetch the combined messages and send them back to the LLM
+                messages = state.get("messages")
+                counter += 1
 
-        logger.warning(
-            "Agent exceeded maximum agent steps of {max_agent_steps}, stopping.", max_agent_steps=self.max_agent_steps
-        )
-        return {**state.data}
+            if counter >= self.max_agent_steps:
+                logger.warning(
+                    "Agent reached maximum agent steps of {max_agent_steps}, stopping.",
+                    max_agent_steps=self.max_agent_steps,
+                )
+            span.set_content_tag("haystack.agent.output", state.data)
+            span.set_tag("haystack.agent.steps_taken", counter)
+            return state.data
 
     def _check_exit_conditions(self, llm_messages: List[ChatMessage], tool_messages: List[ChatMessage]) -> bool:
         """