Skip to content

Commit 9ce8f3d

Browse files
awsarronSourabh SarupriaSourabh Sarupria
authored
Fix agent default callback handler (#170)
Co-authored-by: Sourabh Sarupria <rob.sarupria@mac.chi.chicorp> Co-authored-by: Sourabh Sarupria <rob.sarupria@e5865.x.akamaiedge.net>
1 parent da55dc8 commit 9ce8f3d

File tree

2 files changed

+60
-6
lines changed

2 files changed

+60
-6
lines changed

src/strands/agent/agent.py

Lines changed: 29 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,16 @@
4444
logger = logging.getLogger(__name__)
4545

4646

47+
# Sentinel class and object to distinguish between explicit None and default parameter value
48+
class _DefaultCallbackHandlerSentinel:
49+
"""Sentinel class to distinguish between explicit None and default parameter value."""
50+
51+
pass
52+
53+
54+
_DEFAULT_CALLBACK_HANDLER = _DefaultCallbackHandlerSentinel()
55+
56+
4757
class Agent:
4858
"""Core Agent interface.
4959
@@ -70,7 +80,7 @@ def __init__(self, agent: "Agent") -> None:
7080
# agent tools and thus break their execution.
7181
self._agent = agent
7282

73-
def __getattr__(self, name: str) -> Callable:
83+
def __getattr__(self, name: str) -> Callable[..., Any]:
7484
"""Call tool as a function.
7585
7686
This method enables the method-style interface (e.g., `agent.tool.tool_name(param="value")`).
@@ -177,7 +187,9 @@ def __init__(
177187
messages: Optional[Messages] = None,
178188
tools: Optional[List[Union[str, Dict[str, str], Any]]] = None,
179189
system_prompt: Optional[str] = None,
180-
callback_handler: Optional[Callable] = PrintingCallbackHandler(),
190+
callback_handler: Optional[
191+
Union[Callable[..., Any], _DefaultCallbackHandlerSentinel]
192+
] = _DEFAULT_CALLBACK_HANDLER,
181193
conversation_manager: Optional[ConversationManager] = None,
182194
max_parallel_tools: int = os.cpu_count() or 1,
183195
record_direct_tool_call: bool = True,
@@ -204,7 +216,8 @@ def __init__(
204216
system_prompt: System prompt to guide model behavior.
205217
If None, the model will behave according to its default settings.
206218
callback_handler: Callback for processing events as they happen during agent execution.
207-
Defaults to strands.handlers.PrintingCallbackHandler if None.
219+
If not provided (using the default), a new PrintingCallbackHandler instance is created.
220+
If explicitly set to None, null_callback_handler is used.
208221
conversation_manager: Manager for conversation history and context window.
209222
Defaults to strands.agent.conversation_manager.SlidingWindowConversationManager if None.
210223
max_parallel_tools: Maximum number of tools to run in parallel when the model returns multiple tool calls.
@@ -222,7 +235,17 @@ def __init__(
222235
self.messages = messages if messages is not None else []
223236

224237
self.system_prompt = system_prompt
225-
self.callback_handler = callback_handler or null_callback_handler
238+
239+
# If not provided, create a new PrintingCallbackHandler instance
240+
# If explicitly set to None, use null_callback_handler
241+
# Otherwise use the passed callback_handler
242+
self.callback_handler: Union[Callable[..., Any], PrintingCallbackHandler]
243+
if isinstance(callback_handler, _DefaultCallbackHandlerSentinel):
244+
self.callback_handler = PrintingCallbackHandler()
245+
elif callback_handler is None:
246+
self.callback_handler = null_callback_handler
247+
else:
248+
self.callback_handler = callback_handler
226249

227250
self.conversation_manager = conversation_manager if conversation_manager else SlidingWindowConversationManager()
228251

@@ -415,7 +438,7 @@ def target_callback() -> None:
415438
thread.join()
416439

417440
def _run_loop(
418-
self, prompt: str, kwargs: Any, supplementary_callback_handler: Optional[Callable] = None
441+
self, prompt: str, kwargs: Dict[str, Any], supplementary_callback_handler: Optional[Callable[..., Any]] = None
419442
) -> AgentResult:
420443
"""Execute the agent's event loop with the given prompt and parameters."""
421444
try:
@@ -441,7 +464,7 @@ def _run_loop(
441464
finally:
442465
self.conversation_manager.apply_management(self)
443466

444-
def _execute_event_loop_cycle(self, callback_handler: Callable, kwargs: dict[str, Any]) -> AgentResult:
467+
def _execute_event_loop_cycle(self, callback_handler: Callable[..., Any], kwargs: Dict[str, Any]) -> AgentResult:
445468
"""Execute the event loop cycle with retry logic for context window limits.
446469
447470
This internal method handles the execution of the event loop cycle and implements

tests/strands/agent/test_agent.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -686,6 +686,37 @@ def test_agent_with_callback_handler_none_uses_null_handler():
686686
assert agent.callback_handler == null_callback_handler
687687

688688

689+
def test_agent_callback_handler_not_provided_creates_new_instances():
690+
"""Test that when callback_handler is not provided, new PrintingCallbackHandler instances are created."""
691+
# Create two agents without providing callback_handler
692+
agent1 = Agent()
693+
agent2 = Agent()
694+
695+
# Both should have PrintingCallbackHandler instances
696+
assert isinstance(agent1.callback_handler, PrintingCallbackHandler)
697+
assert isinstance(agent2.callback_handler, PrintingCallbackHandler)
698+
699+
# But they should be different object instances
700+
assert agent1.callback_handler is not agent2.callback_handler
701+
702+
703+
def test_agent_callback_handler_explicit_none_uses_null_handler():
704+
"""Test that when callback_handler is explicitly set to None, null_callback_handler is used."""
705+
agent = Agent(callback_handler=None)
706+
707+
# Should use null_callback_handler
708+
assert agent.callback_handler is null_callback_handler
709+
710+
711+
def test_agent_callback_handler_custom_handler_used():
712+
"""Test that when a custom callback_handler is provided, it is used."""
713+
custom_handler = unittest.mock.Mock()
714+
agent = Agent(callback_handler=custom_handler)
715+
716+
# Should use the provided custom handler
717+
assert agent.callback_handler is custom_handler
718+
719+
689720
@pytest.mark.asyncio
690721
async def test_stream_async_returns_all_events(mock_event_loop_cycle):
691722
agent = Agent()

0 commit comments

Comments
 (0)