
Commit 185e1c7

sjrl and vblagoje authored
feat: Agent tracing (#9240)
* Agent tracing * Small changes * Some changes and refactoring * Refactoring to reuse code * Fix * Add reno * Fix tests * Fix tests * Fix linting * Refactor and add tracing support to run_async of Agent * Reduce duplicate code * Remove finalize_run * Use break instead of copying code three times * Adding a test * Add tracing unit tests * Make async tracing test actually run async * Increase test coverage * Unit test for traces in pipeline * Add cleanup * Fix proper indentation * PR comments * PR comments and new test * Update warning message * Update warning message --------- Co-authored-by: Vladimir Blagojevic <dovlex@gmail.com>
1 parent 656fe6d commit 185e1c7

File tree

7 files changed: +651 -238 lines changed


haystack/components/agents/agent.py: +130 -87
@@ -2,14 +2,15 @@
 #
 # SPDX-License-Identifier: Apache-2.0
 
-import asyncio
 import inspect
 from copy import deepcopy
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, Iterator, List, Optional
 
-from haystack import component, default_from_dict, default_to_dict, logging
+from haystack import component, default_from_dict, default_to_dict, logging, tracing
 from haystack.components.generators.chat.types import ChatGenerator
 from haystack.components.tools import ToolInvoker
+from haystack.core.pipeline.async_pipeline import AsyncPipeline
+from haystack.core.pipeline.pipeline import Pipeline
 from haystack.core.serialization import component_to_dict
 from haystack.dataclasses import ChatMessage
 from haystack.dataclasses.state import State, _schema_from_dict, _schema_to_dict, _validate_schema
@@ -187,6 +188,26 @@ def from_dict(cls, data: Dict[str, Any]) -> "Agent":
 
         return default_from_dict(cls, data)
 
+    def _prepare_generator_inputs(self, streaming_callback: Optional[StreamingCallbackT] = None) -> Dict[str, Any]:
+        """Prepare inputs for the chat generator."""
+        generator_inputs = {"tools": self.tools}
+        selected_callback = streaming_callback or self.streaming_callback
+        if selected_callback is not None:
+            generator_inputs["streaming_callback"] = selected_callback
+        return generator_inputs
+
+    def _create_agent_span(self) -> Iterator[tracing.Span]:
+        """Create a span for the agent run."""
+        return tracing.tracer.trace(
+            "haystack.agent.run",
+            tags={
+                "haystack.agent.max_steps": self.max_agent_steps,
+                "haystack.agent.tools": self.tools,
+                "haystack.agent.exit_conditions": self.exit_conditions,
+                "haystack.agent.state_schema": self.state_schema,
+            },
+        )
+
     def run(
         self,
         messages: List[ChatMessage],
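
The two helpers above are the heart of the change: `_prepare_generator_inputs()` centralizes the streaming-callback plumbing that was previously duplicated in `run()` and `run_async()`, and `_create_agent_span()` wraps every agent run in a `haystack.agent.run` span via `tracing.tracer.trace()`. Any tracer registered with `tracing.enable_tracing()` therefore receives the agent's tags. Below is a minimal sketch of such a tracer, assuming the public `haystack.tracing` interface (`Tracer`, `Span`, `enable_tracing`); the `CollectingSpan`/`CollectingTracer` names are illustrative and not part of this commit (the commit's tests use their own spying tracer):

import contextlib
import dataclasses
from typing import Any, Dict, Iterator, List, Optional

from haystack import tracing


@dataclasses.dataclass
class CollectingSpan(tracing.Span):
    operation_name: str
    tags: Dict[str, Any] = dataclasses.field(default_factory=dict)

    def set_tag(self, key: str, value: Any) -> None:
        # Receives tags such as "haystack.agent.max_steps" and, at the end of a
        # run, "haystack.agent.steps_taken"; set_content_tag() also lands here
        # when content tracing is enabled.
        self.tags[key] = value


class CollectingTracer(tracing.Tracer):
    def __init__(self) -> None:
        self.spans: List[CollectingSpan] = []

    @contextlib.contextmanager
    def trace(
        self, operation_name: str, tags: Optional[Dict[str, Any]] = None, parent_span: Optional[tracing.Span] = None
    ) -> Iterator[CollectingSpan]:
        # "haystack.agent.run" arrives here, as do the child spans opened for
        # the chat_generator and tool_invoker steps.
        span = CollectingSpan(operation_name=operation_name, tags=dict(tags or {}))
        self.spans.append(span)
        yield span

    def current_span(self) -> Optional[CollectingSpan]:
        return self.spans[-1] if self.spans else None


tracing.enable_tracing(CollectingTracer())  # all Agent runs now emit spans here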
@@ -205,48 +226,66 @@ def run(
         if not self._is_warmed_up and hasattr(self.chat_generator, "warm_up"):
             raise RuntimeError("The component Agent wasn't warmed up. Run 'warm_up()' before calling 'run()'.")
 
-        state = State(schema=self.state_schema, data=kwargs)
-
         if self.system_prompt is not None:
             messages = [ChatMessage.from_system(self.system_prompt)] + messages
+
+        input_data = deepcopy({"messages": messages, "streaming_callback": streaming_callback, **kwargs})
+
+        state = State(schema=self.state_schema, data=kwargs)
         state.set("messages", messages)
 
-        generator_inputs: Dict[str, Any] = {"tools": self.tools}
+        generator_inputs = self._prepare_generator_inputs(streaming_callback=streaming_callback)
+
+        component_visits = dict.fromkeys(["chat_generator", "tool_invoker"], 0)
+        with self._create_agent_span() as span:
+            span.set_content_tag("haystack.agent.input", input_data)
+            counter = 0
+            while counter < self.max_agent_steps:
+                # 1. Call the ChatGenerator
+                llm_messages = Pipeline._run_component(
+                    component_name="chat_generator",
+                    component={"instance": self.chat_generator},
+                    inputs={"messages": messages, **generator_inputs},
+                    component_visits=component_visits,
+                    parent_span=span,
+                )["replies"]
+                state.set("messages", llm_messages)
+
+                # 2. Check if any of the LLM responses contain a tool call
+                if not any(msg.tool_call for msg in llm_messages):
+                    counter += 1
+                    break
 
-        selected_callback = streaming_callback or self.streaming_callback
-        if selected_callback is not None:
-            generator_inputs["streaming_callback"] = selected_callback
+                # 3. Call the ToolInvoker
+                # We only send the messages from the LLM to the tool invoker
+                tool_invoker_result = Pipeline._run_component(
+                    component_name="tool_invoker",
+                    component={"instance": self._tool_invoker},
+                    inputs={"messages": llm_messages, "state": state},
+                    component_visits=component_visits,
+                    parent_span=span,
+                )
+                tool_messages = tool_invoker_result["tool_messages"]
+                state = tool_invoker_result["state"]
+                state.set("messages", tool_messages)
 
-        # Repeat until the exit condition is met
-        counter = 0
-        while counter < self.max_agent_steps:
-            # 1. Call the ChatGenerator
-            llm_messages = self.chat_generator.run(messages=messages, **generator_inputs)["replies"]
-            state.set("messages", llm_messages)
-
-            # 2. Check if any of the LLM responses contain a tool call
-            if not any(msg.tool_call for msg in llm_messages):
-                return {**state.data}
-
-            # 3. Call the ToolInvoker
-            # We only send the messages from the LLM to the tool invoker
-            tool_invoker_result = self._tool_invoker.run(messages=llm_messages, state=state)
-            tool_messages = tool_invoker_result["tool_messages"]
-            state = tool_invoker_result["state"]
-            state.set("messages", tool_messages)
-
-            # 4. Check if any LLM message's tool call name matches an exit condition
-            if self.exit_conditions != ["text"] and self._check_exit_conditions(llm_messages, tool_messages):
-                return {**state.data}
-
-            # 5. Fetch the combined messages and send them back to the LLM
-            messages = state.get("messages")
-            counter += 1
-
-        logger.warning(
-            "Agent exceeded maximum agent steps of {max_agent_steps}, stopping.", max_agent_steps=self.max_agent_steps
-        )
-        return {**state.data}
+                # 4. Check if any LLM message's tool call name matches an exit condition
+                if self.exit_conditions != ["text"] and self._check_exit_conditions(llm_messages, tool_messages):
+                    counter += 1
+                    break
+
+                # 5. Fetch the combined messages and send them back to the LLM
+                messages = state.get("messages")
+                counter += 1
+
+            if counter >= self.max_agent_steps:
+                logger.warning(
+                    "Agent reached maximum agent steps of {max_agent_steps}, stopping.",
+                    max_agent_steps=self.max_agent_steps,
+                )
+            span.set_content_tag("haystack.agent.output", state.data)
+            span.set_tag("haystack.agent.steps_taken", counter)
+            return state.data
 
     async def run_async(
         self,
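
Note how the rewritten `run()` delegates each step to `Pipeline._run_component()` with `parent_span=span`, so the chat generator and tool invoker emit child spans under `haystack.agent.run`, while `set_content_tag()` attaches the agent's input and output only when content tracing is enabled. A minimal usage sketch; `OpenAIChatGenerator` is a stand-in for any ChatGenerator and assumes `OPENAI_API_KEY` is set in the environment:

from haystack import tracing
from haystack.components.agents import Agent
from haystack.components.generators.chat import OpenAIChatGenerator
from haystack.dataclasses import ChatMessage

# Same effect as exporting HAYSTACK_CONTENT_TRACING_ENABLED=true; without it,
# set_content_tag() drops "haystack.agent.input" and "haystack.agent.output".
tracing.tracer.is_content_tracing_enabled = True

agent = Agent(chat_generator=OpenAIChatGenerator(), tools=[])
agent.warm_up()  # run() raises a RuntimeError if a generator needing warm-up wasn't warmed up
result = agent.run(messages=[ChatMessage.from_user("What is Haystack?")])
print(result["messages"][-1].text)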
@@ -270,66 +309,70 @@ async def run_async(
         if not self._is_warmed_up and hasattr(self.chat_generator, "warm_up"):
             raise RuntimeError("The component Agent wasn't warmed up. Run 'warm_up()' before calling 'run_async()'.")
 
-        state = State(schema=self.state_schema, data=kwargs)
-
         if self.system_prompt is not None:
             messages = [ChatMessage.from_system(self.system_prompt)] + messages
-        state.set("messages", messages)
 
-        generator_inputs: Dict[str, Any] = {"tools": self.tools}
+        input_data = deepcopy({"messages": messages, "streaming_callback": streaming_callback, **kwargs})
 
-        selected_callback = streaming_callback or self.streaming_callback
-        if selected_callback is not None:
-            generator_inputs["streaming_callback"] = selected_callback
+        state = State(schema=self.state_schema, data=kwargs)
+        state.set("messages", messages)
 
-        # Repeat until the exit condition is met
-        counter = 0
-        while counter < self.max_agent_steps:
-            # 1. Call the ChatGenerator
-            # Check if the chat generator supports async execution
-            if getattr(self.chat_generator, "__haystack_supports_async__", False):
-                result = await self.chat_generator.run_async(messages=messages, **generator_inputs)  # type: ignore[attr-defined]
-                llm_messages = result["replies"]
-            else:
-                # Fall back to synchronous run if async is not available
-                loop = asyncio.get_running_loop()
-                result = await loop.run_in_executor(
-                    None, lambda: self.chat_generator.run(messages=messages, **generator_inputs)
+        generator_inputs = self._prepare_generator_inputs(streaming_callback=streaming_callback)
+
+        component_visits = dict.fromkeys(["chat_generator", "tool_invoker"], 0)
+        with self._create_agent_span() as span:
+            span.set_content_tag("haystack.agent.input", input_data)
+            counter = 0
+            while counter < self.max_agent_steps:
+                # 1. Call the ChatGenerator
+                result = await AsyncPipeline._run_component_async(
+                    component_name="chat_generator",
+                    component={"instance": self.chat_generator},
+                    component_inputs={"messages": messages, **generator_inputs},
+                    component_visits=component_visits,
+                    max_runs_per_component=self.max_agent_steps,
+                    parent_span=span,
                 )
                 llm_messages = result["replies"]
+                state.set("messages", llm_messages)
+
+                # 2. Check if any of the LLM responses contain a tool call
+                if not any(msg.tool_call for msg in llm_messages):
+                    counter += 1
+                    break
 
-            state.set("messages", llm_messages)
-
-            # 2. Check if any of the LLM responses contain a tool call
-            if not any(msg.tool_call for msg in llm_messages):
-                return {**state.data}
-
-            # 3. Call the ToolInvoker
-            # We only send the messages from the LLM to the tool invoker
-            # Check if the ToolInvoker supports async execution. Currently, it doesn't.
-            if getattr(self._tool_invoker, "__haystack_supports_async__", False):
-                tool_invoker_result = await self._tool_invoker.run_async(messages=llm_messages, state=state)  # type: ignore[attr-defined]
-            else:
-                loop = asyncio.get_running_loop()
-                tool_invoker_result = await loop.run_in_executor(
-                    None, lambda: self._tool_invoker.run(messages=llm_messages, state=state)
+                # 3. Call the ToolInvoker
+                # We only send the messages from the LLM to the tool invoker
+                # Check if the ToolInvoker supports async execution. Currently, it doesn't.
+                tool_invoker_result = await AsyncPipeline._run_component_async(
+                    component_name="tool_invoker",
+                    component={"instance": self._tool_invoker},
+                    component_inputs={"messages": llm_messages, "state": state},
+                    component_visits=component_visits,
+                    max_runs_per_component=self.max_agent_steps,
+                    parent_span=span,
                 )
-            tool_messages = tool_invoker_result["tool_messages"]
-            state = tool_invoker_result["state"]
-            state.set("messages", tool_messages)
+                tool_messages = tool_invoker_result["tool_messages"]
+                state = tool_invoker_result["state"]
+                state.set("messages", tool_messages)
 
-            # 4. Check if any LLM message's tool call name matches an exit condition
-            if self.exit_conditions != ["text"] and self._check_exit_conditions(llm_messages, tool_messages):
-                return {**state.data}
+                # 4. Check if any LLM message's tool call name matches an exit condition
+                if self.exit_conditions != ["text"] and self._check_exit_conditions(llm_messages, tool_messages):
+                    counter += 1
+                    break
 
-            # 5. Fetch the combined messages and send them back to the LLM
-            messages = state.get("messages")
-            counter += 1
+                # 5. Fetch the combined messages and send them back to the LLM
+                messages = state.get("messages")
+                counter += 1
 
-        logger.warning(
-            "Agent exceeded maximum agent steps of {max_agent_steps}, stopping.", max_agent_steps=self.max_agent_steps
-        )
-        return {**state.data}
+            if counter >= self.max_agent_steps:
+                logger.warning(
+                    "Agent reached maximum agent steps of {max_agent_steps}, stopping.",
+                    max_agent_steps=self.max_agent_steps,
+                )
+            span.set_content_tag("haystack.agent.output", state.data)
+            span.set_tag("haystack.agent.steps_taken", counter)
+            return state.data
 
     def _check_exit_conditions(self, llm_messages: List[ChatMessage], tool_messages: List[ChatMessage]) -> bool:
         """

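`run_async()` now mirrors the synchronous loop through `AsyncPipeline._run_component_async()` instead of hand-rolling per-component `run_in_executor` fallbacks, so both entry points produce the same span hierarchy and tags. A minimal async sketch under the same assumptions as the `run()` example above:

import asyncio

from haystack.components.agents import Agent
from haystack.components.generators.chat import OpenAIChatGenerator
from haystack.dataclasses import ChatMessage


async def main() -> None:
    agent = Agent(chat_generator=OpenAIChatGenerator(), tools=[])
    agent.warm_up()
    # Per the retained comment in the diff, components without async support
    # (such as the ToolInvoker) still work through this path.
    result = await agent.run_async(messages=[ChatMessage.from_user("What is Haystack?")])
    print(result["messages"][-1].text)


asyncio.run(main())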