Skip to content

Commit

Permalink
[BUG] Data Visualization Agent not fixing code when run inside SQL Da…
Browse files Browse the repository at this point in the history
…ta Analyst Agent #32
  • Loading branch information
mdancho84 committed Jan 18, 2025
1 parent e1ba000 commit 627027b
Showing 1 changed file with 12 additions and 2 deletions.
14 changes: 12 additions & 2 deletions ai_data_science_team/multiagents/sql_data_analyst.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ def update_params(self, **kwargs):
self._params[k] = v
self._compiled_graph = self._make_compiled_graph()

def ainvoke_agent(self, user_instructions, **kwargs):
def ainvoke_agent(self, user_instructions, max_retries:int=3, retry_count:int=0, **kwargs):
"""
Asynchronosly nvokes the SQL Data Analyst Multi-Agent.
Expand Down Expand Up @@ -146,21 +146,27 @@ def ainvoke_agent(self, user_instructions, **kwargs):
"""
response = self._compiled_graph.ainvoke({
"user_instructions": user_instructions,
"max_retries": max_retries,
"retry_count": retry_count,
}, **kwargs)

if response.get("messages"):
response["messages"] = remove_consecutive_duplicates(response["messages"])

self.response = response

def invoke_agent(self, user_instructions, **kwargs):
def invoke_agent(self, user_instructions, max_retries:int=3, retry_count:int=0, **kwargs):
"""
Invokes the SQL Data Analyst Multi-Agent.
Parameters:
----------
user_instructions: str
The user's instructions for the combined SQL and (optionally) Data Visualization agents.
max_retries (int):
Maximum retry attempts for cleaning.
retry_count (int):
Current retry attempt.
**kwargs:
Additional keyword arguments to pass to the compiled graph's `invoke` method.
Expand Down Expand Up @@ -208,6 +214,8 @@ def invoke_agent(self, user_instructions, **kwargs):
"""
response = self._compiled_graph.invoke({
"user_instructions": user_instructions,
"max_retries": max_retries,
"retry_count": retry_count,
}, **kwargs)

if response.get("messages"):
Expand Down Expand Up @@ -341,6 +349,8 @@ class PrimaryState(TypedDict):
plot_required: bool
data_visualization_function: str
plotly_graph: dict
max_retries: int
retry_count: int

def route_to_visualization(state) -> Command[Literal["data_visualization_agent", "__end__"]]:

Expand Down

0 comments on commit 627027b

Please sign in to comment.