-
Notifications
You must be signed in to change notification settings - Fork 130
feat: Add support for AG-UI #143
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
adlio
wants to merge
1
commit into
strands-agents:main
Choose a base branch
from
adlio:agui
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from all commits
Commits
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,8 +1,8 @@ | ||
"""A framework for building, deploying, and managing AI agents.""" | ||
|
||
from . import agent, event_loop, models, telemetry, types | ||
from . import agent, agui, event_loop, models, telemetry, types | ||
from .agent.agent import Agent | ||
from .tools.decorator import tool | ||
from .tools.thread_pool_executor import ThreadPoolExecutorWrapper | ||
|
||
__all__ = ["Agent", "ThreadPoolExecutorWrapper", "agent", "event_loop", "models", "tool", "types", "telemetry"] | ||
__all__ = ["Agent", "ThreadPoolExecutorWrapper", "agent", "agui", "event_loop", "models", "tool", "types", "telemetry"] |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
"""Strands AG-UI integration package. | ||
|
||
This package provides integration between Strands agents and AG-UI protocol | ||
compatible frontends, including state management and event streaming. | ||
""" | ||
|
||
from .bridge import ( | ||
AGUIEventType, | ||
StrandsAGUIBridge, | ||
StrandsAGUIEndpoint, | ||
create_strands_agui_setup, | ||
) | ||
from .state_tools import ( | ||
StrandsStateManager, | ||
emit_ui_update, | ||
get_agent_state, | ||
get_state_manager, | ||
set_agent_state, | ||
setup_agent_state_management, | ||
update_agent_state, | ||
) | ||
|
||
__all__ = [ | ||
"AGUIEventType", | ||
"StrandsAGUIBridge", | ||
"StrandsAGUIEndpoint", | ||
"create_strands_agui_setup", | ||
"StrandsStateManager", | ||
"emit_ui_update", | ||
"get_agent_state", | ||
"get_state_manager", | ||
"set_agent_state", | ||
"setup_agent_state_management", | ||
"update_agent_state", | ||
] |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,295 @@ | ||
"""Strands to AG-UI Protocol Bridge with State Management. | ||
|
||
This module converts Strands agent events to AG-UI protocol events, | ||
including full state management support for AG-UI compatible frontends. | ||
""" | ||
|
||
import json | ||
import logging | ||
from datetime import datetime | ||
from enum import Enum | ||
from typing import Any, AsyncIterator, Dict, List, Optional | ||
from uuid import uuid4 | ||
|
||
# Use relative import to avoid module name conflict | ||
from .state_tools import StrandsStateManager, get_state_manager, setup_agent_state_management | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
class AGUIEventType(str, Enum): | ||
"""AG-UI Protocol Event Types.""" | ||
|
||
# Run lifecycle events | ||
RUN_STARTED = "RUN_STARTED" | ||
RUN_FINISHED = "RUN_FINISHED" | ||
RUN_ERROR = "RUN_ERROR" | ||
|
||
# Message events | ||
TEXT_MESSAGE_START = "TEXT_MESSAGE_START" | ||
TEXT_MESSAGE_CONTENT = "TEXT_MESSAGE_CONTENT" | ||
TEXT_MESSAGE_END = "TEXT_MESSAGE_END" | ||
|
||
# Tool events | ||
TOOL_CALL_START = "TOOL_CALL_START" | ||
TOOL_CALL_ARGS = "TOOL_CALL_ARGS" | ||
TOOL_CALL_END = "TOOL_CALL_END" | ||
|
||
# State events - KEY for AG-UI compatibility! | ||
STATE_SNAPSHOT = "STATE_SNAPSHOT" | ||
STATE_DELTA = "STATE_DELTA" | ||
|
||
# Step events | ||
STEP_STARTED = "STEP_STARTED" | ||
STEP_FINISHED = "STEP_FINISHED" | ||
|
||
# Custom events | ||
CUSTOM = "CUSTOM" | ||
RAW = "RAW" | ||
|
||
|
||
class StrandsAGUIBridge: | ||
"""Bridge that converts Strands agent events to AG-UI protocol events with state management.""" | ||
|
||
def __init__(self, agent: Any, state_manager: Optional[StrandsStateManager] = None) -> None: | ||
"""Initialize the Strands AG-UI bridge. | ||
|
||
Args: | ||
agent: The Strands agent to bridge | ||
state_manager: Optional state manager, will use global one if not provided | ||
""" | ||
self.agent = agent | ||
self.state_manager = state_manager or get_state_manager() | ||
self.current_run_id: Optional[str] = None | ||
self.current_thread_id: Optional[str] = None | ||
self.current_message_id: Optional[str] = None | ||
self.active_tool_calls: Dict[str, Dict[str, Any]] = {} | ||
self.message_started = False | ||
|
||
# Set up state change callback | ||
self.state_manager.add_callback(self._on_state_change) | ||
self._state_change_queue: List[Dict[str, Any]] = [] | ||
|
||
async def stream_agui_events( | ||
self, | ||
prompt: str, | ||
thread_id: Optional[str] = None, | ||
run_id: Optional[str] = None, | ||
initial_state: Optional[Dict[str, Any]] = None, | ||
**kwargs: Any, | ||
) -> AsyncIterator[Dict[str, Any]]: | ||
"""Stream AG-UI protocol events from a Strands agent execution.""" | ||
self.current_run_id = run_id or str(uuid4()) | ||
self.current_thread_id = thread_id or str(uuid4()) | ||
self.current_message_id = str(uuid4()) | ||
self.message_started = False | ||
|
||
# Set initial state if provided | ||
if initial_state: | ||
self.state_manager.set_state(initial_state) | ||
|
||
try: | ||
# Emit run started event | ||
yield self._create_agui_event( | ||
event_type=AGUIEventType.RUN_STARTED, | ||
data={"thread_id": self.current_thread_id, "run_id": self.current_run_id}, | ||
) | ||
|
||
# Emit initial state snapshot | ||
current_state = self.state_manager.get_state() | ||
if current_state: | ||
yield self._create_agui_event(event_type=AGUIEventType.STATE_SNAPSHOT, data={"snapshot": current_state}) | ||
|
||
# Stream agent execution and convert events | ||
async for strands_event in self.agent.stream_async(prompt, **kwargs): | ||
agui_events = self._convert_strands_event_to_agui(strands_event) | ||
for agui_event in agui_events: | ||
yield agui_event | ||
|
||
# Emit any queued state changes | ||
while self._state_change_queue: | ||
state_event = self._state_change_queue.pop(0) | ||
yield state_event | ||
|
||
# Emit message end if we started one | ||
if self.message_started: | ||
yield self._create_agui_event( | ||
event_type=AGUIEventType.TEXT_MESSAGE_END, data={"message_id": self.current_message_id} | ||
) | ||
|
||
# Emit run finished event | ||
yield self._create_agui_event( | ||
event_type=AGUIEventType.RUN_FINISHED, | ||
data={"thread_id": self.current_thread_id, "run_id": self.current_run_id}, | ||
) | ||
|
||
except Exception as e: | ||
logger.error("Error in Strands-AG-UI bridge: %s", e) | ||
yield self._create_agui_event( | ||
event_type=AGUIEventType.RUN_ERROR, data={"message": str(e), "code": type(e).__name__} | ||
) | ||
|
||
def _convert_strands_event_to_agui(self, strands_event: Dict[str, Any]) -> List[Dict[str, Any]]: | ||
"""Convert a Strands event to one or more AG-UI events.""" | ||
agui_events = [] | ||
|
||
# Handle text content streaming | ||
if "data" in strands_event and strands_event["data"]: | ||
# Start message if not already started | ||
if not self.message_started: | ||
agui_events.append( | ||
self._create_agui_event( | ||
event_type=AGUIEventType.TEXT_MESSAGE_START, | ||
data={"message_id": self.current_message_id, "role": "assistant"}, | ||
) | ||
) | ||
self.message_started = True | ||
|
||
# Add content event | ||
agui_events.append( | ||
self._create_agui_event( | ||
event_type=AGUIEventType.TEXT_MESSAGE_CONTENT, | ||
data={"message_id": self.current_message_id, "delta": strands_event["data"]}, | ||
) | ||
) | ||
|
||
# Handle tool execution | ||
if "current_tool_use" in strands_event: | ||
tool_use = strands_event["current_tool_use"] | ||
if tool_use and tool_use.get("toolUseId"): | ||
tool_id = tool_use["toolUseId"] | ||
tool_name = tool_use.get("name", "unknown") | ||
|
||
# Track tool start | ||
if tool_id not in self.active_tool_calls: | ||
self.active_tool_calls[tool_id] = {"name": tool_name, "started": True} | ||
agui_events.append( | ||
self._create_agui_event( | ||
event_type=AGUIEventType.TOOL_CALL_START, | ||
data={ | ||
"tool_call_id": tool_id, | ||
"tool_call_name": tool_name, | ||
"parent_message_id": self.current_message_id, | ||
}, | ||
) | ||
) | ||
|
||
# Handle tool arguments | ||
if "input" in tool_use: | ||
tool_input = tool_use["input"] | ||
if isinstance(tool_input, dict): | ||
tool_input = json.dumps(tool_input) | ||
elif not isinstance(tool_input, str): | ||
tool_input = str(tool_input) | ||
|
||
agui_events.append( | ||
self._create_agui_event( | ||
event_type=AGUIEventType.TOOL_CALL_ARGS, data={"tool_call_id": tool_id, "delta": tool_input} | ||
) | ||
) | ||
|
||
# Handle completion events | ||
if strands_event.get("complete", False): | ||
# End any active tool calls | ||
for tool_id in list(self.active_tool_calls.keys()): | ||
agui_events.append( | ||
self._create_agui_event(event_type=AGUIEventType.TOOL_CALL_END, data={"tool_call_id": tool_id}) | ||
) | ||
del self.active_tool_calls[tool_id] | ||
|
||
return agui_events | ||
|
||
def _on_state_change(self, new_state: Dict[str, Any], updates: Dict[str, Any]) -> None: | ||
"""Handle state changes from the state manager.""" | ||
# Queue state delta event | ||
if updates: | ||
state_event = self._create_agui_event( | ||
event_type=AGUIEventType.STATE_DELTA, data={"delta": self._dict_to_json_patch(updates)} | ||
) | ||
self._state_change_queue.append(state_event) | ||
|
||
def _dict_to_json_patch(self, updates: Dict[str, Any]) -> List[Dict[str, Any]]: | ||
"""Convert dictionary updates to JSON Patch format.""" | ||
patches = [] | ||
for key, value in updates.items(): | ||
if value is None: | ||
patches.append({"op": "remove", "path": f"/{key}"}) | ||
else: | ||
patches.append({"op": "replace", "path": f"/{key}", "value": value}) | ||
return patches | ||
|
||
def _create_agui_event(self, event_type: AGUIEventType, data: Dict[str, Any]) -> Dict[str, Any]: | ||
"""Create a properly formatted AG-UI protocol event.""" | ||
return {"type": event_type.value, "timestamp": int(datetime.now().timestamp() * 1000), **data} | ||
|
||
|
||
class StrandsAGUIEndpoint: | ||
"""HTTP endpoint that serves AG-UI events from Strands agents with state management.""" | ||
|
||
def __init__(self, agents: Dict[str, Any]) -> None: | ||
"""Initialize the Strands AG-UI endpoint. | ||
|
||
Args: | ||
agents: Dictionary of agent name to agent instance | ||
""" | ||
self.agents = agents | ||
self.bridges: Dict[str, StrandsAGUIBridge] = {} | ||
|
||
# Create bridges for each agent | ||
for name, agent in agents.items(): | ||
self.bridges[name] = StrandsAGUIBridge(agent) | ||
|
||
async def handle_request(self, request_data: Dict[str, Any]) -> AsyncIterator[str]: | ||
"""Handle AG-UI protocol HTTP request and stream SSE responses.""" | ||
agent_name = request_data.get("agent") | ||
messages = request_data.get("messages", []) | ||
thread_id = request_data.get("threadId") | ||
run_id = request_data.get("runId") | ||
frontend_state = request_data.get("state", {}) | ||
|
||
if not agent_name or agent_name not in self.agents: | ||
yield f"data: {json.dumps({'type': 'RUN_ERROR', 'message': 'Agent not found'})}\n\n" | ||
return | ||
|
||
bridge = self.bridges[agent_name] | ||
|
||
# Extract the latest user message | ||
user_messages = [msg for msg in messages if msg.get("role") == "user"] | ||
if not user_messages: | ||
yield f"data: {json.dumps({'type': 'RUN_ERROR', 'message': 'No user message found'})}\n\n" | ||
return | ||
|
||
latest_message = user_messages[-1] | ||
latest_prompt = latest_message.get("content", "") | ||
if isinstance(latest_prompt, list) and latest_prompt: | ||
latest_prompt = latest_prompt[0].get("text", "") | ||
|
||
try: | ||
async for event in bridge.stream_agui_events( | ||
prompt=latest_prompt, thread_id=thread_id, run_id=run_id, initial_state=frontend_state | ||
): | ||
yield f"data: {json.dumps(event)}\n\n" | ||
except Exception as e: | ||
logger.error("Error in AG-UI endpoint: %s", e) | ||
error_event = {"type": "RUN_ERROR", "message": str(e), "code": type(e).__name__} | ||
yield f"data: {json.dumps(error_event)}\n\n" | ||
|
||
|
||
def create_strands_agui_setup( | ||
agents: Dict[str, Any], initial_states: Optional[Dict[str, Dict[str, Any]]] = None | ||
) -> StrandsAGUIEndpoint: | ||
"""Create a complete Strands + AG-UI setup with state management. | ||
|
||
Args: | ||
agents: Dictionary mapping agent names to agent instances | ||
initial_states: Optional dictionary mapping agent names to their initial states | ||
|
||
Returns: | ||
A configured StrandsAGUIEndpoint instance | ||
""" | ||
# Set up state management for each agent | ||
for name, agent in agents.items(): | ||
initial_state = initial_states.get(name, {}) if initial_states else {} | ||
setup_agent_state_management(agent, initial_state) | ||
|
||
return StrandsAGUIEndpoint(agents) |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It looks like these are already defined in the AG-UI Python SDK - should we use that Enum instead: https://github.com/ag-ui-protocol/ag-ui/blob/d53b012ff8051420905d85bb1a443c2729616d88/python-sdk/ag_ui/core/events.py#L29
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I agree, but I'd assumed the team would want to avoid having all strands agents depend on ag_ui (it's barely older than strands itself, and not all agents will have a web-based UI). Would you prefer: 1) Taking a hard-dependency on ag_ui and adding a bridge directly in the core package, or 2) Having an ag_ui example that could potentially turn into a strands-ui package if it succeeds?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I agree on avoiding a direct dependency - at this point we'd be looking towards an example/sample.
There's still a lot of open questions around the other features (#33, #31, multi-agent) which the team is starting work towards and knowing the gaps we have is really useful - but it's still too early for a direct integration.