Skip to content

Commit

Permalink
✨ Agents now take parent task history into account (#491)
Browse files Browse the repository at this point in the history
* ✨ [WIP] Agents now take parent task history into account

- The AnalyzerAgent no longer uses hardcoded language, source, and target, but instead gathers it from the task context

Signed-off-by: Fabian von Feilitzsch <fabian@fabianism.us>

* 🐛 minor fixes

Signed-off-by: Pranav Gaikwad <pgaikwad@redhat.com>

* remove creation_order from task

Signed-off-by: Pranav Gaikwad <pgaikwad@redhat.com>

* 🐛 fix parsing error in analysis agent

Signed-off-by: Pranav Gaikwad <pgaikwad@redhat.com>

* 🐛 minor fixes

Signed-off-by: Pranav Gaikwad <pgaikwad@redhat.com>

* 🐛 fix formatting, fix agents

Signed-off-by: Pranav Gaikwad <pgaikwad@redhat.com>

* 🐛 fix bug with dependency agent

Signed-off-by: Pranav Gaikwad <pgaikwad@redhat.com>

* 👻 trunk fixes

Signed-off-by: Pranav Gaikwad <pgaikwad@redhat.com>

---------

Signed-off-by: Fabian von Feilitzsch <fabian@fabianism.us>
Signed-off-by: Pranav Gaikwad <pgaikwad@redhat.com>
Co-authored-by: Pranav Gaikwad <pgaikwad@redhat.com>
  • Loading branch information
fabianvf and pranavgaikwad authored Feb 6, 2025
1 parent 6073542 commit 0abd4fc
Show file tree
Hide file tree
Showing 19 changed files with 376 additions and 64 deletions.
18 changes: 12 additions & 6 deletions kai/reactive_codeplanner/agent/analyzer_fix/agent.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os
import re
from dataclasses import dataclass
from logging import DEBUG
from pathlib import Path
Expand Down Expand Up @@ -32,6 +33,7 @@ class AnalyzerAgent(Agent):
system_message_template = Template(
"""
You are an experienced {{ language }} developer, who specializes in migrating code from {{ source }} to {{ target }}
{{ background }}
"""
)

Expand Down Expand Up @@ -123,6 +125,7 @@ def execute(self, ask: AgentRequest) -> AnalyzerFixResponse:
language=language,
source=source,
target=target,
background=ask.background,
)
)

Expand All @@ -141,7 +144,7 @@ def execute(self, ask: AgentRequest) -> AnalyzerFixResponse:
ask.cache_path_resolver,
)

resp = self.parse_llm_response(ai_message, language)
resp = self.parse_llm_response(ai_message)
return AnalyzerFixResponse(
encountered_errors=[],
file_to_modify=Path(os.path.abspath(ask.file_path)),
Expand All @@ -150,7 +153,7 @@ def execute(self, ask: AgentRequest) -> AnalyzerFixResponse:
updated_file_content=resp.source_file,
)

def parse_llm_response(self, message: BaseMessage, language: str) -> _llm_response:
def parse_llm_response(self, message: BaseMessage) -> _llm_response:
"""Private method that will be used to parse the contents and get the results"""

lines_of_output = cast(str, message.content).splitlines()
Expand All @@ -162,23 +165,26 @@ def parse_llm_response(self, message: BaseMessage, language: str) -> _llm_respon
reasoning = ""
additional_details = ""
for line in lines_of_output:
if line.strip() == "## Reasoning":
# trunk-ignore(cspell/error)
if re.match(r"(?:##|\*\*)\s+[Rr]easoning", line.strip()):
in_reasoning = True
in_source_file = False
in_additional_details = False
continue
if line.strip() == f"## Updated {language} File":
# trunk-ignore(cspell/error)
if re.match(r"(?:##|\*\*)\s+[Uu]pdated.*[Ff]ile", line.strip()):
in_source_file = True
in_reasoning = False
in_additional_details = False
continue
if "## Additional Information" in line.strip():
# trunk-ignore(cspell/error)
if re.match(r"(?:##|\*\*)\s+[Aa]dditional\s+[Ii]nformation", line.strip()):
in_reasoning = False
in_source_file = False
in_additional_details = True
continue
if in_source_file:
if f"```{language}" in line or "```" in line:
if re.match(r"```(?:\w*)", line):
continue
source_file = "\n".join([source_file, line])
if in_reasoning:
Expand Down
1 change: 1 addition & 0 deletions kai/reactive_codeplanner/agent/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
class AgentRequest:
file_path: Path
task: Task
background: str
cache_path_resolver: CachePathResolver = field(init=False)

def __post_init__(self) -> None:
Expand Down
32 changes: 22 additions & 10 deletions kai/reactive_codeplanner/agent/dependency_agent/dependency_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from pathlib import Path
from typing import Any, Callable, Optional, TypedDict, Union

from jinja2 import Template
from langchain.prompts.chat import HumanMessagePromptTemplate
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage

Expand Down Expand Up @@ -53,9 +54,10 @@ class _llm_response:

class MavenDependencyAgent(Agent):

sys_msg = SystemMessage(
system_message_template = Template(
"""
You are an excellent java developer focused on updating dependencies in a maven `pom.xml` file.
{{ background }}
### Guidelines:
1 Only use the provided and predefined functions as the functions. Do not use any other functions.
Expand Down Expand Up @@ -115,7 +117,7 @@ class MavenDependencyAgent(Agent):
```
Observation: We now have the fqdn for the commons-collections4 dependency
Though: Now I have the latest version information I need to find the where guava is in the file to replace it.
Thought: Now I have the latest version information I need to find the where guava is in the file to replace it.
Action: ```python
start_line, end_line = find_in_pom._run(relative_file_path="module/file.py", keywords={"groupId": "com.google.guava"", "artifactId": "guava")
```
Expand All @@ -136,7 +138,6 @@ class MavenDependencyAgent(Agent):

inst_msg_template = HumanMessagePromptTemplate.from_template(
"""
[INST]
Given the message, you should determine the dependency that needs to be changed.
You must use the following format:
Expand Down Expand Up @@ -190,7 +191,13 @@ def execute(self, ask: AgentRequest) -> AgentResult:
if not request.message:
return AgentResult()

msg = [self.sys_msg, self.inst_msg_template.format(message=request.message)]
system_message = SystemMessage(
content=self.system_message_template.render(background=ask.background)
)

content = self.inst_msg_template.format(message=request.message)

msg = [system_message, content]
fix_gen_attempts = 0
llm_response: Optional[_llm_response] = None
maven_search: Optional[FQDNResponse] = None
Expand Down Expand Up @@ -238,8 +245,11 @@ def execute(self, ask: AgentRequest) -> AgentResult:
)
if to_llm_message is not None and callable(to_llm_message):
tool_outputs.append(method_out.to_llm_message().content)

msg.append(HumanMessage(content="\n".join(tool_outputs)))
if tool_outputs:
msg.append(HumanMessage(content="\n".join(tool_outputs)))
else:
# we cannot continue the chat when we dont have any tool outputs
break

if llm_response is None or fix_gen_response is None:
return AgentResult()
Expand All @@ -256,12 +266,13 @@ def execute(self, ask: AgentRequest) -> AgentResult:
logger.info("Need to call sub-agent for selecting FQDN")
r = self.child_agent.execute(
FQDNDependencySelectorRequest(
request.file_path,
ask.task,
request.message,
a.code,
file_path=request.file_path,
task=ask.task,
msg=request.message,
code=a.code,
query=[],
times=0,
background=ask.background,
)
)
if r.response is not None and isinstance(r.response, list):
Expand Down Expand Up @@ -354,6 +365,7 @@ def parse_llm_response(
code_block = ""
thought_str = ""
observation_str = ""
final_answer = " ".join(parts[1:]).strip()
in_final_answer = True
in_code = False
in_thought = False
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@ def execute(self, ask: AgentRequest) -> FQDNDependencySelectorResult:
query=query,
times=ask.times + 1,
task=ask.task,
background=ask.background,
)
)
if isinstance(response, list):
Expand Down
23 changes: 16 additions & 7 deletions kai/reactive_codeplanner/agent/maven_compiler_fix/agent.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import re

from jinja2 import Template
from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage

Expand All @@ -13,8 +15,8 @@


class MavenCompilerAgent(Agent):
system_message = SystemMessage(
content="""
system_message_template = Template(
"""{{ background }}
I will give you compiler errors and the offending line of code, and you will need to use the file to determine how to fix them. You should only use compiler errors to determine what to fix.
Make sure that the references to any changed types are kept.
Expand Down Expand Up @@ -65,13 +67,17 @@ def execute(self, ask: AgentRequest) -> AgentResult:
)
return MavenCompilerAgentResult()

system_message = SystemMessage(
content=self.system_message_template.render(background=ask.background)
)

compile_errors = f"Line of code: {line_of_code};\n{ask.message}"
content = self.chat_message_template.render(
src_file_contents=ask.file_contents, compile_errors=compile_errors
)

ai_message = self.model_provider.invoke(
[self.system_message, HumanMessage(content=content)],
[system_message, HumanMessage(content=content)],
ask.cache_path_resolver,
)

Expand All @@ -97,23 +103,26 @@ def parse_llm_response(self, message: BaseMessage) -> MavenCompilerAgentResult:
reasoning = ""
additional_details = ""
for line in lines_of_output:
if line.strip() == "## Updated Java File":
# trunk-ignore(cspell/error)
if re.match(r"(?:##|\*\*)\s+[Uu]pdated.*[Ff]ile", line.strip()):
in_java_file = True
in_reasoning = False
in_additional_details = False
continue
if line.strip() == "## Reasoning":
# trunk-ignore(cspell/error)
if re.match(r"(?:##|\*\*)\s+[Rr]easoning", line.strip()):
in_java_file = False
in_reasoning = True
in_additional_details = False
continue
if line.strip() == "## Additional Information (optional)":
# trunk-ignore(cspell/error)
if re.match(r"(?:##|\*\*)\s+[Aa]dditional\s+[Ii]nformation", line.strip()):
in_reasoning = False
in_java_file = False
in_additional_details = True
continue
if in_java_file:
if "```java" in line or "```" in line or line == "\n":
if re.match(r"```(?:\w*)", line):
continue
java_file = "\n".join([java_file, line]).strip()
if in_reasoning:
Expand Down
1 change: 1 addition & 0 deletions kai/reactive_codeplanner/agent/maven_compiler_fix/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,4 +42,5 @@ def to_reflection_task(self) -> Optional[ReflectionTask]:
updated_file_contents=self.updated_file_contents,
original_file_contents=self.original_file,
task=self.task,
background="",
)
2 changes: 1 addition & 1 deletion kai/reactive_codeplanner/agent/reflection_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -473,7 +473,7 @@ def _parse_llm_response(
if isinstance(content, list):
return None
match_updated_file = re.search(
r"[##|\*\*] [U|u]pdated [F|f]ile\s+.*?```\w+\n([\s\S]*?)```", # trunk-ignore(cspell)
r"(?:##|\*\*)\s+[Uu]pdated.*[Ff]ile\s+.*?```\w+\n([\s\S]*?)```", # trunk-ignore(cspell)
content,
re.DOTALL,
)
Expand Down
1 change: 1 addition & 0 deletions kai/reactive_codeplanner/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,7 @@ def main() -> None:
task_manager.supply_result(result)
except Exception as e:
logger.error("Failed to supply result %s: %s", result, e)

supply_time = time.time() - start_supply_time
logger.info("PERFORMANCE: %.6f seconds to supply result", supply_time)

Expand Down
14 changes: 14 additions & 0 deletions kai/reactive_codeplanner/task_manager/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,9 @@ class Task:
children: list["Task"] = field(default_factory=list, compare=False)
retry_count: int = 0
max_retries: int = 3
result: Optional["TaskResult"] = None

_creation_counter = 0

def oldest_ancestor(self) -> "Task":
if self.parent:
Expand Down Expand Up @@ -67,6 +70,10 @@ def truncate(value: str, max_length: int = 30) -> str:

return f"{class_name}<" + ", ".join(field_strings) + ">"

def background(self) -> str:
"""Used by Agents to provide context when solving child issues"""
raise NotImplementedError

__repr__ = __str__


Expand Down Expand Up @@ -130,6 +137,12 @@ def __str__(self) -> str:

return f"{self.__class__.__name__}<loc={self.file}:{self.line}:{self.column}, message={self.message}>(priority={self.priority}({shadowed_priority}), depth={self.depth}, retries={self.retry_count})"

def background(self) -> str:
"""Used by Agents to provide context when solving child issues"""
if self.parent is None:
return ""
return self.oldest_ancestor().background()

__repr__ = __str__


Expand All @@ -138,6 +151,7 @@ def __str__(self) -> str:
class TaskResult:
encountered_errors: list[str]
modified_files: list[Path]
summary: str


@dataclass
Expand Down
5 changes: 4 additions & 1 deletion kai/reactive_codeplanner/task_manager/task_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,9 +92,12 @@ def execute_task(self, task: Task) -> TaskResult:
result = agent.execute_task(self.rcm, task)
except Exception as e:
logger.exception("Unhandled exception executing task %s", task)
result = TaskResult(encountered_errors=[str(e)], modified_files=[])
result = TaskResult(
encountered_errors=[str(e)], modified_files=[], summary=""
)

logger.debug("Task execution result: %s", result)
task.result = result
return result

def get_agent_for_task(self, task: Task) -> TaskRunner:
Expand Down
16 changes: 16 additions & 0 deletions kai/reactive_codeplanner/task_runner/analyzer_lsp/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,22 @@ def targets(self) -> list[str]:
target.sort()
return target

def background(self) -> str:
if self.parent is not None:
return self.oldest_ancestor().background()
if self.children:
message = f"""You attempted to solve the following issues in the source code you are migrating:
Issues: {"\n".join(list(set(self.incident_message)))}"""

# TODO(pgaikwad): we need to ensure this doesn't confuse the agents more than it helps before adding it back
# if self.result and self.result.summary:
# message += f"\n\nHere is the reasoning you provided for your initial solution:\n\n{self.result.summary}"

message += "\n\nHowever your solution caused additional problems elsewhere in the repository."
return message

return ""

@cached_property
def incident_message(self) -> list[str]:
incident_msg_list = [i.message for i in self.incidents]
Expand Down
Loading

0 comments on commit 0abd4fc

Please sign in to comment.