From 35c6b13c40d1cc9fc3beb058e01d1c0a7aa40171 Mon Sep 17 00:00:00 2001 From: Aaron Longwell Date: Thu, 29 May 2025 10:27:34 -0700 Subject: [PATCH] feat: Add support for AG-UI Adds an AG-UI "bridge" to enable Strands Agents to be the backend for an agentic frontend like Copilot Kit. --- src/strands/__init__.py | 4 +- src/strands/agui/__init__.py | 35 +++ src/strands/agui/bridge.py | 295 ++++++++++++++++++ src/strands/agui/state_tools.py | 196 ++++++++++++ tests/strands/agui/__init__.py | 1 + tests/strands/agui/test_bridge.py | 362 ++++++++++++++++++++++ tests/strands/agui/test_state_tools.py | 402 +++++++++++++++++++++++++ 7 files changed, 1293 insertions(+), 2 deletions(-) create mode 100644 src/strands/agui/__init__.py create mode 100644 src/strands/agui/bridge.py create mode 100644 src/strands/agui/state_tools.py create mode 100644 tests/strands/agui/__init__.py create mode 100644 tests/strands/agui/test_bridge.py create mode 100644 tests/strands/agui/test_state_tools.py diff --git a/src/strands/__init__.py b/src/strands/__init__.py index f4b1228d..9588759d 100644 --- a/src/strands/__init__.py +++ b/src/strands/__init__.py @@ -1,8 +1,8 @@ """A framework for building, deploying, and managing AI agents.""" -from . import agent, event_loop, models, telemetry, types +from . import agent, agui, event_loop, models, telemetry, types from .agent.agent import Agent from .tools.decorator import tool from .tools.thread_pool_executor import ThreadPoolExecutorWrapper -__all__ = ["Agent", "ThreadPoolExecutorWrapper", "agent", "event_loop", "models", "tool", "types", "telemetry"] +__all__ = ["Agent", "ThreadPoolExecutorWrapper", "agent", "agui", "event_loop", "models", "tool", "types", "telemetry"] diff --git a/src/strands/agui/__init__.py b/src/strands/agui/__init__.py new file mode 100644 index 00000000..c14a6ed5 --- /dev/null +++ b/src/strands/agui/__init__.py @@ -0,0 +1,35 @@ +"""Strands AG-UI integration package. + +This package provides integration between Strands agents and AG-UI protocol +compatible frontends, including state management and event streaming. +""" + +from .bridge import ( + AGUIEventType, + StrandsAGUIBridge, + StrandsAGUIEndpoint, + create_strands_agui_setup, +) +from .state_tools import ( + StrandsStateManager, + emit_ui_update, + get_agent_state, + get_state_manager, + set_agent_state, + setup_agent_state_management, + update_agent_state, +) + +__all__ = [ + "AGUIEventType", + "StrandsAGUIBridge", + "StrandsAGUIEndpoint", + "create_strands_agui_setup", + "StrandsStateManager", + "emit_ui_update", + "get_agent_state", + "get_state_manager", + "set_agent_state", + "setup_agent_state_management", + "update_agent_state", +] diff --git a/src/strands/agui/bridge.py b/src/strands/agui/bridge.py new file mode 100644 index 00000000..d60631c6 --- /dev/null +++ b/src/strands/agui/bridge.py @@ -0,0 +1,295 @@ +"""Strands to AG-UI Protocol Bridge with State Management. + +This module converts Strands agent events to AG-UI protocol events, +including full state management support for AG-UI compatible frontends. +""" + +import json +import logging +from datetime import datetime +from enum import Enum +from typing import Any, AsyncIterator, Dict, List, Optional +from uuid import uuid4 + +# Use relative import to avoid module name conflict +from .state_tools import StrandsStateManager, get_state_manager, setup_agent_state_management + +logger = logging.getLogger(__name__) + + +class AGUIEventType(str, Enum): + """AG-UI Protocol Event Types.""" + + # Run lifecycle events + RUN_STARTED = "RUN_STARTED" + RUN_FINISHED = "RUN_FINISHED" + RUN_ERROR = "RUN_ERROR" + + # Message events + TEXT_MESSAGE_START = "TEXT_MESSAGE_START" + TEXT_MESSAGE_CONTENT = "TEXT_MESSAGE_CONTENT" + TEXT_MESSAGE_END = "TEXT_MESSAGE_END" + + # Tool events + TOOL_CALL_START = "TOOL_CALL_START" + TOOL_CALL_ARGS = "TOOL_CALL_ARGS" + TOOL_CALL_END = "TOOL_CALL_END" + + # State events - KEY for AG-UI compatibility! + STATE_SNAPSHOT = "STATE_SNAPSHOT" + STATE_DELTA = "STATE_DELTA" + + # Step events + STEP_STARTED = "STEP_STARTED" + STEP_FINISHED = "STEP_FINISHED" + + # Custom events + CUSTOM = "CUSTOM" + RAW = "RAW" + + +class StrandsAGUIBridge: + """Bridge that converts Strands agent events to AG-UI protocol events with state management.""" + + def __init__(self, agent: Any, state_manager: Optional[StrandsStateManager] = None) -> None: + """Initialize the Strands AG-UI bridge. + + Args: + agent: The Strands agent to bridge + state_manager: Optional state manager, will use global one if not provided + """ + self.agent = agent + self.state_manager = state_manager or get_state_manager() + self.current_run_id: Optional[str] = None + self.current_thread_id: Optional[str] = None + self.current_message_id: Optional[str] = None + self.active_tool_calls: Dict[str, Dict[str, Any]] = {} + self.message_started = False + + # Set up state change callback + self.state_manager.add_callback(self._on_state_change) + self._state_change_queue: List[Dict[str, Any]] = [] + + async def stream_agui_events( + self, + prompt: str, + thread_id: Optional[str] = None, + run_id: Optional[str] = None, + initial_state: Optional[Dict[str, Any]] = None, + **kwargs: Any, + ) -> AsyncIterator[Dict[str, Any]]: + """Stream AG-UI protocol events from a Strands agent execution.""" + self.current_run_id = run_id or str(uuid4()) + self.current_thread_id = thread_id or str(uuid4()) + self.current_message_id = str(uuid4()) + self.message_started = False + + # Set initial state if provided + if initial_state: + self.state_manager.set_state(initial_state) + + try: + # Emit run started event + yield self._create_agui_event( + event_type=AGUIEventType.RUN_STARTED, + data={"thread_id": self.current_thread_id, "run_id": self.current_run_id}, + ) + + # Emit initial state snapshot + current_state = self.state_manager.get_state() + if current_state: + yield self._create_agui_event(event_type=AGUIEventType.STATE_SNAPSHOT, data={"snapshot": current_state}) + + # Stream agent execution and convert events + async for strands_event in self.agent.stream_async(prompt, **kwargs): + agui_events = self._convert_strands_event_to_agui(strands_event) + for agui_event in agui_events: + yield agui_event + + # Emit any queued state changes + while self._state_change_queue: + state_event = self._state_change_queue.pop(0) + yield state_event + + # Emit message end if we started one + if self.message_started: + yield self._create_agui_event( + event_type=AGUIEventType.TEXT_MESSAGE_END, data={"message_id": self.current_message_id} + ) + + # Emit run finished event + yield self._create_agui_event( + event_type=AGUIEventType.RUN_FINISHED, + data={"thread_id": self.current_thread_id, "run_id": self.current_run_id}, + ) + + except Exception as e: + logger.error("Error in Strands-AG-UI bridge: %s", e) + yield self._create_agui_event( + event_type=AGUIEventType.RUN_ERROR, data={"message": str(e), "code": type(e).__name__} + ) + + def _convert_strands_event_to_agui(self, strands_event: Dict[str, Any]) -> List[Dict[str, Any]]: + """Convert a Strands event to one or more AG-UI events.""" + agui_events = [] + + # Handle text content streaming + if "data" in strands_event and strands_event["data"]: + # Start message if not already started + if not self.message_started: + agui_events.append( + self._create_agui_event( + event_type=AGUIEventType.TEXT_MESSAGE_START, + data={"message_id": self.current_message_id, "role": "assistant"}, + ) + ) + self.message_started = True + + # Add content event + agui_events.append( + self._create_agui_event( + event_type=AGUIEventType.TEXT_MESSAGE_CONTENT, + data={"message_id": self.current_message_id, "delta": strands_event["data"]}, + ) + ) + + # Handle tool execution + if "current_tool_use" in strands_event: + tool_use = strands_event["current_tool_use"] + if tool_use and tool_use.get("toolUseId"): + tool_id = tool_use["toolUseId"] + tool_name = tool_use.get("name", "unknown") + + # Track tool start + if tool_id not in self.active_tool_calls: + self.active_tool_calls[tool_id] = {"name": tool_name, "started": True} + agui_events.append( + self._create_agui_event( + event_type=AGUIEventType.TOOL_CALL_START, + data={ + "tool_call_id": tool_id, + "tool_call_name": tool_name, + "parent_message_id": self.current_message_id, + }, + ) + ) + + # Handle tool arguments + if "input" in tool_use: + tool_input = tool_use["input"] + if isinstance(tool_input, dict): + tool_input = json.dumps(tool_input) + elif not isinstance(tool_input, str): + tool_input = str(tool_input) + + agui_events.append( + self._create_agui_event( + event_type=AGUIEventType.TOOL_CALL_ARGS, data={"tool_call_id": tool_id, "delta": tool_input} + ) + ) + + # Handle completion events + if strands_event.get("complete", False): + # End any active tool calls + for tool_id in list(self.active_tool_calls.keys()): + agui_events.append( + self._create_agui_event(event_type=AGUIEventType.TOOL_CALL_END, data={"tool_call_id": tool_id}) + ) + del self.active_tool_calls[tool_id] + + return agui_events + + def _on_state_change(self, new_state: Dict[str, Any], updates: Dict[str, Any]) -> None: + """Handle state changes from the state manager.""" + # Queue state delta event + if updates: + state_event = self._create_agui_event( + event_type=AGUIEventType.STATE_DELTA, data={"delta": self._dict_to_json_patch(updates)} + ) + self._state_change_queue.append(state_event) + + def _dict_to_json_patch(self, updates: Dict[str, Any]) -> List[Dict[str, Any]]: + """Convert dictionary updates to JSON Patch format.""" + patches = [] + for key, value in updates.items(): + if value is None: + patches.append({"op": "remove", "path": f"/{key}"}) + else: + patches.append({"op": "replace", "path": f"/{key}", "value": value}) + return patches + + def _create_agui_event(self, event_type: AGUIEventType, data: Dict[str, Any]) -> Dict[str, Any]: + """Create a properly formatted AG-UI protocol event.""" + return {"type": event_type.value, "timestamp": int(datetime.now().timestamp() * 1000), **data} + + +class StrandsAGUIEndpoint: + """HTTP endpoint that serves AG-UI events from Strands agents with state management.""" + + def __init__(self, agents: Dict[str, Any]) -> None: + """Initialize the Strands AG-UI endpoint. + + Args: + agents: Dictionary of agent name to agent instance + """ + self.agents = agents + self.bridges: Dict[str, StrandsAGUIBridge] = {} + + # Create bridges for each agent + for name, agent in agents.items(): + self.bridges[name] = StrandsAGUIBridge(agent) + + async def handle_request(self, request_data: Dict[str, Any]) -> AsyncIterator[str]: + """Handle AG-UI protocol HTTP request and stream SSE responses.""" + agent_name = request_data.get("agent") + messages = request_data.get("messages", []) + thread_id = request_data.get("threadId") + run_id = request_data.get("runId") + frontend_state = request_data.get("state", {}) + + if not agent_name or agent_name not in self.agents: + yield f"data: {json.dumps({'type': 'RUN_ERROR', 'message': 'Agent not found'})}\n\n" + return + + bridge = self.bridges[agent_name] + + # Extract the latest user message + user_messages = [msg for msg in messages if msg.get("role") == "user"] + if not user_messages: + yield f"data: {json.dumps({'type': 'RUN_ERROR', 'message': 'No user message found'})}\n\n" + return + + latest_message = user_messages[-1] + latest_prompt = latest_message.get("content", "") + if isinstance(latest_prompt, list) and latest_prompt: + latest_prompt = latest_prompt[0].get("text", "") + + try: + async for event in bridge.stream_agui_events( + prompt=latest_prompt, thread_id=thread_id, run_id=run_id, initial_state=frontend_state + ): + yield f"data: {json.dumps(event)}\n\n" + except Exception as e: + logger.error("Error in AG-UI endpoint: %s", e) + error_event = {"type": "RUN_ERROR", "message": str(e), "code": type(e).__name__} + yield f"data: {json.dumps(error_event)}\n\n" + + +def create_strands_agui_setup( + agents: Dict[str, Any], initial_states: Optional[Dict[str, Dict[str, Any]]] = None +) -> StrandsAGUIEndpoint: + """Create a complete Strands + AG-UI setup with state management. + + Args: + agents: Dictionary mapping agent names to agent instances + initial_states: Optional dictionary mapping agent names to their initial states + + Returns: + A configured StrandsAGUIEndpoint instance + """ + # Set up state management for each agent + for name, agent in agents.items(): + initial_state = initial_states.get(name, {}) if initial_states else {} + setup_agent_state_management(agent, initial_state) + + return StrandsAGUIEndpoint(agents) diff --git a/src/strands/agui/state_tools.py b/src/strands/agui/state_tools.py new file mode 100644 index 00000000..e1a5db38 --- /dev/null +++ b/src/strands/agui/state_tools.py @@ -0,0 +1,196 @@ +"""Strands State Management Tools for AG-UI Compatibility. + +This module provides state management capabilities to Strands agents +through a tool-based approach, making them compatible with AG-UI frontends +that expect state synchronization. +""" + +import json +import threading +from datetime import datetime +from typing import Any, Callable, Dict, List, Optional, TypeVar, cast + +# Use relative imports to avoid module name conflicts +from ..tools.decorator import tool +from ..types.tools import ToolResult, ToolUse + +T = TypeVar("T") + + +class StrandsStateManager: + """Thread-safe state manager for Strands agents.""" + + def __init__(self) -> None: + """Initialize the state manager with empty state.""" + self._state: Dict[str, Any] = {} + self._lock = threading.RLock() + self._callbacks: List[Callable[[Dict[str, Any], Dict[str, Any]], None]] = [] + + def get_state(self) -> Dict[str, Any]: + """Get current state snapshot.""" + with self._lock: + return self._state.copy() + + def update_state(self, updates: Dict[str, Any]) -> Dict[str, Any]: + """Update state and return the new state.""" + with self._lock: + self._state.update(updates) + new_state = self._state.copy() + + # Notify callbacks of state change + for callback in self._callbacks: + callback(new_state, updates) + + return new_state + + def set_state(self, new_state: Dict[str, Any]) -> Dict[str, Any]: + """Replace entire state.""" + with self._lock: + old_state = self._state.copy() + self._state = new_state.copy() + + # Calculate delta for callbacks + delta = self._calculate_delta(old_state, new_state) + for callback in self._callbacks: + callback(new_state, delta) + + return new_state + + def add_callback(self, callback: Any) -> None: + """Add callback for state changes.""" + self._callbacks.append(callback) + + def _calculate_delta(self, old_state: Dict, new_state: Dict) -> Dict[str, Any]: + """Calculate state delta between old and new state.""" + delta = {} + + # Find changed/added keys + for key, value in new_state.items(): + if key not in old_state or old_state[key] != value: + delta[key] = value + + # Find removed keys + for key in old_state: + if key not in new_state: + delta[key] = None + + return delta + + +# Global state manager instance +_state_manager = StrandsStateManager() + + +def get_state_manager() -> StrandsStateManager: + """Get the global state manager instance.""" + return _state_manager + + +@tool +def get_agent_state() -> Dict[str, Any]: + """Get the current agent state. + + This tool allows the agent to read its current state, + which is synchronized with the frontend. + + Returns: + Dictionary containing the current agent state + """ + state = _state_manager.get_state() + return {"status": "success", "content": [{"text": f"Current agent state: {json.dumps(state, indent=2)}"}]} + + +@tool +def update_agent_state(updates: Dict[str, Any]) -> Dict[str, Any]: + """Update specific keys in the agent state. + + This tool allows the agent to update its state, which will be + synchronized with the frontend and trigger UI updates. + + Args: + updates: Dictionary of state updates to apply + + Returns: + Dictionary with the updated state + """ + try: + _state_manager.update_state(updates) + return {"status": "success", "content": [{"text": f"Updated state with: {json.dumps(updates, indent=2)}"}]} + except Exception as e: + return {"status": "error", "content": [{"text": f"Failed to update state: {str(e)}"}]} + + +@tool +def set_agent_state(new_state: Dict[str, Any]) -> Dict[str, Any]: + """Replace the entire agent state. + + This tool allows the agent to completely replace its state, + which will be synchronized with the frontend. + + Args: + new_state: Complete new state to set + + Returns: + Dictionary with the new state + """ + try: + _state_manager.set_state(new_state) + return {"status": "success", "content": [{"text": f"Set new state: {json.dumps(new_state, indent=2)}"}]} + except Exception as e: + return {"status": "error", "content": [{"text": f"Failed to set state: {str(e)}"}]} + + +@tool +def emit_ui_update(component_name: str, props: Dict[str, Any]) -> Dict[str, Any]: + """Emit a UI update event for a specific component. + + This tool allows the agent to trigger specific UI updates + by sending component props to the frontend. + + Args: + component_name: Name of the UI component to update + props: Properties/data to send to the component + + Returns: + Confirmation of the UI update emission + """ + try: + # Update state with UI-specific data + ui_updates = { + f"ui_{component_name}": props, + "last_ui_update": {"component": component_name, "timestamp": datetime.now().isoformat(), "props": props}, + } + + _state_manager.update_state(ui_updates) + + return {"status": "success", "content": [{"text": f"Emitted UI update for {component_name}"}]} + except Exception as e: + return {"status": "error", "content": [{"text": f"Failed to emit UI update: {str(e)}"}]} + + +def setup_agent_state_management(agent: Any, initial_state: Optional[Dict[str, Any]] = None) -> StrandsStateManager: + """Set up state management for a Strands agent. + + Args: + agent: The Strands agent instance + initial_state: Optional initial state to set + + Returns: + The configured state manager + """ + # Import here to avoid circular imports + from ..tools.tools import FunctionTool + + # Add state management tools to agent + state_tools = [get_agent_state, update_agent_state, set_agent_state, emit_ui_update] + + # Add tools to agent's tool registry using FunctionTool wrapper + for tool_func in state_tools: + function_tool = FunctionTool(cast(Callable[[ToolUse], ToolResult], tool_func)) + agent.tool_registry.register_tool(function_tool) + + # Set initial state if provided + if initial_state: + _state_manager.set_state(initial_state) + + return _state_manager diff --git a/tests/strands/agui/__init__.py b/tests/strands/agui/__init__.py new file mode 100644 index 00000000..f6b92d7d --- /dev/null +++ b/tests/strands/agui/__init__.py @@ -0,0 +1 @@ +"""Tests for the AGUI integration module.""" diff --git a/tests/strands/agui/test_bridge.py b/tests/strands/agui/test_bridge.py new file mode 100644 index 00000000..919d6f0f --- /dev/null +++ b/tests/strands/agui/test_bridge.py @@ -0,0 +1,362 @@ +"""Tests for AGUI bridge functionality.""" + +import json +from unittest.mock import MagicMock, Mock + +import pytest + +from strands.agui.bridge import ( + AGUIEventType, + StrandsAGUIBridge, + StrandsAGUIEndpoint, + create_strands_agui_setup, +) +from strands.agui.state_tools import StrandsStateManager + + +class TestAGUIEventType: + """Test AGUI event type enumeration.""" + + def test_event_types_exist(self): + """Test that all expected event types are defined.""" + expected_events = { + "RUN_STARTED", + "RUN_FINISHED", + "RUN_ERROR", + "TEXT_MESSAGE_START", + "TEXT_MESSAGE_CONTENT", + "TEXT_MESSAGE_END", + "TOOL_CALL_START", + "TOOL_CALL_ARGS", + "TOOL_CALL_END", + "STATE_SNAPSHOT", + "STATE_DELTA", + "STEP_STARTED", + "STEP_FINISHED", + "CUSTOM", + "RAW", + } + + for event_name in expected_events: + assert hasattr(AGUIEventType, event_name) + + def test_event_type_values(self): + """Test that event types have correct string values.""" + assert AGUIEventType.RUN_STARTED == "RUN_STARTED" + assert AGUIEventType.STATE_SNAPSHOT == "STATE_SNAPSHOT" + assert AGUIEventType.STATE_DELTA == "STATE_DELTA" + assert AGUIEventType.TOOL_CALL_START == "TOOL_CALL_START" + + def test_enum_inheritance(self): + """Test that AGUIEventType properly inherits from str and Enum.""" + event = AGUIEventType.RUN_STARTED + assert isinstance(event, str) + assert event.value == "RUN_STARTED" + + +class TestStrandsAGUIBridge: + """Test the StrandsAGUIBridge class.""" + + def setup_method(self): + """Set up test fixtures.""" + self.mock_agent = MagicMock() + self.mock_state_manager = Mock(spec=StrandsStateManager) + self.mock_state_manager.get_state.return_value = {} + self.bridge = StrandsAGUIBridge(self.mock_agent, self.mock_state_manager) + + def test_initialization(self): + """Test bridge initialization.""" + assert self.bridge.agent is self.mock_agent + assert self.bridge.state_manager is self.mock_state_manager + assert self.bridge.current_run_id is None + assert self.bridge.current_thread_id is None + assert self.bridge.current_message_id is None + assert self.bridge.active_tool_calls == {} + assert self.bridge.message_started is False + + def test_initialization_default_state_manager(self): + """Test initialization without explicit state manager.""" + bridge = StrandsAGUIBridge(self.mock_agent) + + assert bridge.agent is self.mock_agent + assert isinstance(bridge.state_manager, StrandsStateManager) + + def test_convert_strands_event_to_agui_text_content(self): + """Test converting text content events.""" + strands_event = {"data": "Hello world"} + + agui_events = self.bridge._convert_strands_event_to_agui(strands_event) + + assert len(agui_events) == 2 + # First should be message start + assert agui_events[0]["type"] == AGUIEventType.TEXT_MESSAGE_START + # Second should be content + assert agui_events[1]["type"] == AGUIEventType.TEXT_MESSAGE_CONTENT + assert agui_events[1]["delta"] == "Hello world" + + def test_convert_strands_event_to_agui_tool_call(self): + """Test converting tool call events.""" + strands_event = { + "current_tool_use": {"toolUseId": "call_123", "name": "test_tool", "input": {"param": "value"}} + } + + agui_events = self.bridge._convert_strands_event_to_agui(strands_event) + + # Should have tool call start and args events + tool_events = [ + e for e in agui_events if e["type"] in [AGUIEventType.TOOL_CALL_START, AGUIEventType.TOOL_CALL_ARGS] + ] + assert len(tool_events) >= 2 + + def test_convert_strands_event_to_agui_completion(self): + """Test converting completion events.""" + # First add a tool call to track + self.bridge.active_tool_calls["call_123"] = {"name": "test_tool"} + + strands_event = {"complete": True} + agui_events = self.bridge._convert_strands_event_to_agui(strands_event) + + # Should have tool call end event + end_events = [e for e in agui_events if e["type"] == AGUIEventType.TOOL_CALL_END] + assert len(end_events) == 1 + + def test_on_state_change(self): + """Test state change callback handling.""" + new_state = {"key": "new_value"} + updates = {"key": "new_value"} + + self.bridge._on_state_change(new_state, updates) + + # Should queue state change event + assert len(self.bridge._state_change_queue) == 1 + event = self.bridge._state_change_queue[0] + assert event["type"] == AGUIEventType.STATE_DELTA + + def test_dict_to_json_patch(self): + """Test dictionary to JSON patch conversion.""" + updates = {"new_key": "new_value", "updated_key": "updated_value", "removed_key": None} + + patches = self.bridge._dict_to_json_patch(updates) + + assert len(patches) == 3 + + # Check patch operations + ops = {patch["op"] for patch in patches} + assert "replace" in ops + assert "remove" in ops + + def test_create_agui_event(self): + """Test AGUI event creation.""" + event_type = AGUIEventType.RUN_STARTED + data = {"thread_id": "test_thread"} + + event = self.bridge._create_agui_event(event_type, data) + + assert event["type"] == event_type.value + assert event["thread_id"] == "test_thread" + assert "timestamp" in event + assert isinstance(event["timestamp"], int) + + def test_convert_strands_event_empty(self): + """Test converting empty strands event.""" + strands_event = {} + + agui_events = self.bridge._convert_strands_event_to_agui(strands_event) + + # Should return empty list for empty event + assert agui_events == [] + + def test_message_state_tracking(self): + """Test that message state is tracked correctly.""" + # Initially no message started + assert self.bridge.message_started is False + + # Process text content event + strands_event = {"data": "Hello"} + self.bridge._convert_strands_event_to_agui(strands_event) + + # Should have started message + assert self.bridge.message_started is True + + # Second text event should not start new message + strands_event2 = {"data": " world"} + agui_events2 = self.bridge._convert_strands_event_to_agui(strands_event2) + + # Should only have content event, not start event + start_events = [e for e in agui_events2 if e["type"] == AGUIEventType.TEXT_MESSAGE_START] + assert len(start_events) == 0 + + +class TestStrandsAGUIEndpoint: + """Test the StrandsAGUIEndpoint class.""" + + def setup_method(self): + """Set up test fixtures.""" + self.mock_agent = MagicMock() + self.agents = {"test_agent": self.mock_agent} + self.endpoint = StrandsAGUIEndpoint(self.agents) + + def test_initialization(self): + """Test endpoint initialization.""" + assert self.endpoint.agents == self.agents + assert "test_agent" in self.endpoint.bridges + assert isinstance(self.endpoint.bridges["test_agent"], StrandsAGUIBridge) + + @pytest.mark.asyncio + async def test_handle_request_agent_not_found(self): + """Test handling request for non-existent agent.""" + request_data = {"agent": "nonexistent_agent", "messages": [{"role": "user", "content": "Hello"}]} + + responses = [] + async for response in self.endpoint.handle_request(request_data): + responses.append(response) + + assert len(responses) == 1 + response_data = json.loads(responses[0][6:-2]) # Parse SSE data + assert response_data["type"] == "RUN_ERROR" + assert "Agent not found" in response_data["message"] + + @pytest.mark.asyncio + async def test_handle_request_no_user_message(self): + """Test handling request without user message.""" + request_data = {"agent": "test_agent", "messages": []} + + responses = [] + async for response in self.endpoint.handle_request(request_data): + responses.append(response) + + assert len(responses) == 1 + response_data = json.loads(responses[0][6:-2]) + assert response_data["type"] == "RUN_ERROR" + assert "No user message found" in response_data["message"] + + def test_message_content_extraction(self): + """Test extraction of message content from various formats.""" + # Test simple string content + messages = [{"role": "user", "content": "Simple message"}] + user_messages = [msg for msg in messages if msg.get("role") == "user"] + content = user_messages[-1].get("content", "") + assert content == "Simple message" + + # Test array content + messages2 = [{"role": "user", "content": [{"text": "Array message"}]}] + user_messages2 = [msg for msg in messages2 if msg.get("role") == "user"] + content2 = user_messages2[-1].get("content", "") + if isinstance(content2, list) and content2: + content2 = content2[0].get("text", "") + assert content2 == "Array message" + + +class TestCreateStrandsAGUISetup: + """Test the create_strands_agui_setup function.""" + + def test_create_setup_basic(self): + """Test basic setup creation.""" + mock_agent = MagicMock() + mock_agent.tool_registry = MagicMock() + agents = {"test_agent": mock_agent} + + endpoint = create_strands_agui_setup(agents) + + assert isinstance(endpoint, StrandsAGUIEndpoint) + assert endpoint.agents == agents + assert "test_agent" in endpoint.bridges + + def test_create_setup_with_initial_states(self): + """Test setup with initial states.""" + mock_agent = MagicMock() + mock_agent.tool_registry = MagicMock() + agents = {"test_agent": mock_agent} + initial_states = {"test_agent": {"initial": "value"}} + + endpoint = create_strands_agui_setup(agents, initial_states) + + assert isinstance(endpoint, StrandsAGUIEndpoint) + # Should have called setup_agent_state_management for each agent + mock_agent.tool_registry.register_tool.assert_called() + + def test_create_setup_multiple_agents(self): + """Test setup with multiple agents.""" + mock_agent1 = MagicMock() + mock_agent1.tool_registry = MagicMock() + mock_agent2 = MagicMock() + mock_agent2.tool_registry = MagicMock() + + agents = {"agent1": mock_agent1, "agent2": mock_agent2} + + endpoint = create_strands_agui_setup(agents) + + assert len(endpoint.bridges) == 2 + assert "agent1" in endpoint.bridges + assert "agent2" in endpoint.bridges + + +class TestIntegration: + """Integration tests for the complete AGUI bridge system.""" + + def test_state_synchronization(self): + """Test state synchronization between bridge and state manager.""" + mock_agent = MagicMock() + bridge = StrandsAGUIBridge(mock_agent) + + # Verify callback was registered + assert len(bridge.state_manager._callbacks) > 0 + + # Trigger state change + bridge.state_manager.update_state({"test": "value"}) + + # Should have queued state event + assert len(bridge._state_change_queue) > 0 + + # Event should be STATE_DELTA type + event = bridge._state_change_queue[0] + assert event["type"] == AGUIEventType.STATE_DELTA + + def test_event_timestamp_consistency(self): + """Test that events have consistent timestamp format.""" + mock_agent = MagicMock() + bridge = StrandsAGUIBridge(mock_agent) + + # Create multiple events + events = [ + bridge._create_agui_event(AGUIEventType.RUN_STARTED, {}), + bridge._create_agui_event(AGUIEventType.STATE_SNAPSHOT, {"snapshot": {}}), + bridge._create_agui_event(AGUIEventType.RUN_FINISHED, {}), + ] + + # All should have timestamps + for event in events: + assert "timestamp" in event + assert isinstance(event["timestamp"], int) + # Should be reasonable timestamp (Unix milliseconds) + assert event["timestamp"] > 1000000000000 + + def test_json_patch_format(self): + """Test JSON patch format consistency.""" + mock_agent = MagicMock() + bridge = StrandsAGUIBridge(mock_agent) + + updates = {"key1": "value1", "key2": None, "key3": "value3"} + patches = bridge._dict_to_json_patch(updates) + + # Should have proper JSON patch structure + for patch in patches: + assert "op" in patch + assert "path" in patch + assert patch["op"] in ["replace", "remove"] + assert patch["path"].startswith("/") + + def test_bridge_endpoint_integration(self): + """Test integration between bridge and endpoint.""" + mock_agent = MagicMock() + mock_agent.tool_registry = MagicMock() + agents = {"test_agent": mock_agent} + + endpoint = create_strands_agui_setup(agents) + bridge = endpoint.bridges["test_agent"] + + # Bridge should have same agent + assert bridge.agent is mock_agent + + # Bridge should be in endpoint + assert endpoint.bridges["test_agent"] is bridge diff --git a/tests/strands/agui/test_state_tools.py b/tests/strands/agui/test_state_tools.py new file mode 100644 index 00000000..69619c29 --- /dev/null +++ b/tests/strands/agui/test_state_tools.py @@ -0,0 +1,402 @@ +"""Tests for AGUI state management tools.""" + +import json +import threading +import time +from typing import Any, Dict +from unittest.mock import MagicMock, patch + +from strands.agui.state_tools import ( + StrandsStateManager, + emit_ui_update, + get_agent_state, + get_state_manager, + set_agent_state, + setup_agent_state_management, + update_agent_state, +) +from strands.tools.tools import FunctionTool + + +class TestStrandsStateManager: + """Test the StrandsStateManager class.""" + + def test_initialization(self): + """Test that StrandsStateManager initializes correctly.""" + manager = StrandsStateManager() + + assert manager._state == {} + assert manager._callbacks == [] + assert manager._lock is not None + + def test_get_state_empty(self): + """Test getting state when it's empty.""" + manager = StrandsStateManager() + state = manager.get_state() + + assert state == {} + + def test_update_state_basic(self): + """Test basic state update functionality.""" + manager = StrandsStateManager() + updates = {"key1": "value1", "counter": 42} + + new_state = manager.update_state(updates) + + assert new_state == updates + assert manager.get_state() == updates + + def test_update_state_multiple(self): + """Test multiple state updates.""" + manager = StrandsStateManager() + + manager.update_state({"key1": "value1"}) + manager.update_state({"key2": "value2"}) + manager.update_state({"key1": "updated_value1"}) + + final_state = manager.get_state() + expected = {"key1": "updated_value1", "key2": "value2"} + + assert final_state == expected + + def test_set_state_replace(self): + """Test that set_state completely replaces the state.""" + manager = StrandsStateManager() + + # Set initial state + manager.update_state({"old_key": "old_value"}) + + # Replace with new state + new_state = {"new_key": "new_value"} + result_state = manager.set_state(new_state) + + assert result_state == new_state + assert manager.get_state() == new_state + assert "old_key" not in manager.get_state() + + def test_state_isolation(self): + """Test that returned state is isolated from internal state.""" + manager = StrandsStateManager() + manager.update_state({"key": "value"}) + + state1 = manager.get_state() + state2 = manager.get_state() + + # Modify one copy + state1["key"] = "modified" + + # Other copy should be unchanged + assert state2["key"] == "value" + assert manager.get_state()["key"] == "value" + + def test_callback_system(self): + """Test state change callback system.""" + manager = StrandsStateManager() + callback_calls = [] + + def test_callback(new_state: Dict[str, Any], updates: Dict[str, Any]) -> None: + callback_calls.append((new_state.copy(), updates.copy())) + + manager.add_callback(test_callback) + manager.update_state({"key": "value"}) + + assert len(callback_calls) == 1 + new_state, updates = callback_calls[0] + assert new_state == {"key": "value"} + assert updates == {"key": "value"} + + def test_multiple_callbacks(self): + """Test that multiple callbacks are triggered.""" + manager = StrandsStateManager() + calls1 = [] + calls2 = [] + + def callback1(new_state, updates): + calls1.append((new_state, updates)) + + def callback2(new_state, updates): + calls2.append((new_state, updates)) + + manager.add_callback(callback1) + manager.add_callback(callback2) + manager.update_state({"test": "value"}) + + assert len(calls1) == 1 + assert len(calls2) == 1 + + def test_calculate_delta(self): + """Test state delta calculation.""" + manager = StrandsStateManager() + + old_state = {"key1": "value1", "key2": "value2"} + new_state = {"key1": "updated_value1", "key3": "value3"} + + delta = manager._calculate_delta(old_state, new_state) + + # Delta should include removed keys as None and new/updated keys + expected = {"key1": "updated_value1", "key3": "value3", "key2": None} + assert delta == expected + + def test_thread_safety(self): + """Test that the state manager is thread-safe.""" + manager = StrandsStateManager() + results = [] + errors = [] + + def worker(worker_id: int) -> None: + try: + for i in range(10): + key = f"worker_{worker_id}_item_{i}" + manager.update_state({key: f"value_{i}"}) + time.sleep(0.001) # Small delay to increase chance of race conditions + results.append(worker_id) + except Exception as e: + errors.append(f"Worker {worker_id} error: {e}") + + # Start multiple worker threads + threads = [] + for i in range(3): + thread = threading.Thread(target=worker, args=(i,)) + threads.append(thread) + thread.start() + + # Wait for all threads to complete + for thread in threads: + thread.join() + + # Check results + assert len(errors) == 0, f"Thread safety errors: {errors}" + assert len(results) == 3 + + final_state = manager.get_state() + # Should have 3 workers × 10 items each = 30 keys + assert len(final_state) == 30 + + +class TestGlobalStateManager: + """Test the global state manager functions.""" + + def test_get_state_manager_singleton(self): + """Test that get_state_manager returns the same instance.""" + manager1 = get_state_manager() + manager2 = get_state_manager() + + assert manager1 is manager2 + assert isinstance(manager1, StrandsStateManager) + + def test_state_persistence_across_calls(self): + """Test that state persists across different calls.""" + manager1 = get_state_manager() + manager1.update_state({"persistent": "value"}) + + manager2 = get_state_manager() + state = manager2.get_state() + + assert state["persistent"] == "value" + + +class TestStateTool: + """Test the state management tool functions.""" + + def test_get_agent_state_success(self): + """Test successful state retrieval.""" + # Set up known state + manager = get_state_manager() + manager.set_state({"test_key": "test_value"}) + + result = get_agent_state() + + assert result["status"] == "success" + assert len(result["content"]) == 1 + assert "test_key" in result["content"][0]["text"] + assert "test_value" in result["content"][0]["text"] + + def test_update_agent_state_success(self): + """Test successful state update.""" + updates = {"new_key": "new_value", "counter": 100} + + result = update_agent_state(updates) + + assert result["status"] == "success" + assert "Updated state with" in result["content"][0]["text"] + + # Verify state was actually updated + manager = get_state_manager() + state = manager.get_state() + assert state["new_key"] == "new_value" + assert state["counter"] == 100 + + def test_set_agent_state_success(self): + """Test successful state replacement.""" + # Set initial state + get_state_manager().update_state({"old_key": "old_value"}) + + new_state = {"replaced_key": "replaced_value"} + result = set_agent_state(new_state) + + assert result["status"] == "success" + assert "Set new state" in result["content"][0]["text"] + + # Verify state was replaced + manager = get_state_manager() + current_state = manager.get_state() + assert current_state == new_state + assert "old_key" not in current_state + + def test_emit_ui_update_success(self): + """Test successful UI update emission.""" + component_name = "GameBoard" + props = {"score": 100, "level": 5} + + result = emit_ui_update(component_name, props) + + assert result["status"] == "success" + assert f"Emitted UI update for {component_name}" in result["content"][0]["text"] + + # Verify UI update was stored in state - checking actual implementation format + manager = get_state_manager() + state = manager.get_state() + # Based on implementation, it stores as ui_{component_name} and last_ui_update + assert f"ui_{component_name}" in state + assert state[f"ui_{component_name}"] == props + assert "last_ui_update" in state + + @patch("strands.agui.state_tools._state_manager") + def test_update_agent_state_error_handling(self, mock_manager): + """Test error handling in update_agent_state.""" + mock_manager.update_state.side_effect = Exception("Test error") + + result = update_agent_state({"test": "value"}) + + assert result["status"] == "error" + assert "Test error" in result["content"][0]["text"] + + @patch("strands.agui.state_tools._state_manager") + def test_set_agent_state_error_handling(self, mock_manager): + """Test error handling in set_agent_state.""" + mock_manager.set_state.side_effect = Exception("Set error") + + result = set_agent_state({"test": "value"}) + + assert result["status"] == "error" + assert "Set error" in result["content"][0]["text"] + + @patch("strands.agui.state_tools._state_manager") + def test_emit_ui_update_error_handling(self, mock_manager): + """Test error handling in emit_ui_update.""" + mock_manager.update_state.side_effect = Exception("UI error") + + result = emit_ui_update("TestComponent", {"prop": "value"}) + + assert result["status"] == "error" + assert "UI error" in result["content"][0]["text"] + + +class TestAgentSetup: + """Test agent state management setup.""" + + def test_setup_agent_state_management(self): + """Test setting up state management for an agent.""" + # Mock agent with tool registry + mock_tool_registry = MagicMock() + mock_agent = MagicMock() + mock_agent.tool_registry = mock_tool_registry + + initial_state = {"initial_key": "initial_value"} + + result_manager = setup_agent_state_management(mock_agent, initial_state) + + # Should return the state manager + assert isinstance(result_manager, StrandsStateManager) + + # Should have registered 4 tools + assert mock_tool_registry.register_tool.call_count == 4 + + # Verify tools are FunctionTool instances + for call in mock_tool_registry.register_tool.call_args_list: + tool = call[0][0] + assert isinstance(tool, FunctionTool) + + # Verify initial state was set + current_state = result_manager.get_state() + assert current_state["initial_key"] == "initial_value" + + def test_setup_agent_state_management_no_initial_state(self): + """Test setup without initial state.""" + mock_tool_registry = MagicMock() + mock_agent = MagicMock() + mock_agent.tool_registry = mock_tool_registry + + result_manager = setup_agent_state_management(mock_agent) + + assert isinstance(result_manager, StrandsStateManager) + assert mock_tool_registry.register_tool.call_count == 4 + + def test_tool_names_registered(self): + """Test that correct tool names are registered.""" + mock_tool_registry = MagicMock() + mock_agent = MagicMock() + mock_agent.tool_registry = mock_tool_registry + + setup_agent_state_management(mock_agent) + + # Extract tool names from registered tools + registered_tools = [] + for call in mock_tool_registry.register_tool.call_args_list: + tool = call[0][0] + # Use _name attribute which is the actual attribute in FunctionTool + registered_tools.append(tool._name) + + expected_tools = {"get_agent_state", "update_agent_state", "set_agent_state", "emit_ui_update"} + + assert set(registered_tools) == expected_tools + + +class TestIntegration: + """Integration tests for the complete state management system.""" + + def test_state_callback_integration(self): + """Test integration between state updates and callbacks.""" + manager = get_state_manager() + manager.set_state({}) # Clear state + + callback_events = [] + + def test_callback(new_state, updates): + callback_events.append({"new_state": new_state.copy(), "updates": updates.copy()}) + + manager.add_callback(test_callback) + + # Perform operations using tools + update_agent_state({"key1": "value1"}) + set_agent_state({"key2": "value2"}) + emit_ui_update("TestComponent", {"prop": "value"}) + + # Should have triggered callbacks + assert len(callback_events) >= 3 + + # Final state should contain the UI update based on actual implementation + final_state = manager.get_state() + assert "ui_TestComponent" in final_state # Actual format from implementation + assert "TestComponent" in str(final_state) # Component name should be in the state somehow + + def test_json_serialization(self): + """Test that state can be JSON serialized.""" + manager = get_state_manager() + manager.set_state( + { + "string": "value", + "number": 42, + "boolean": True, + "null": None, + "list": [1, 2, 3], + "dict": {"nested": "value"}, + } + ) + + state = manager.get_state() + json_str = json.dumps(state) + + # Should not raise exception + parsed = json.loads(json_str) + assert parsed == state