Skip to content

Commit

Permalink
Fixing up the symbol not found and dependency agents
Browse files Browse the repository at this point in the history
* Make it so that the FQDN agent doesn't cause LLM to cycle on itself
* Make it so the FQDN agent has a system prompt, that will help with
  context window issues
* Make it so that the SymbolNotFound issues go to the maven compiler.
  This is because if the dependency is missing from the project both a
symbol not found and a package does not exist are created. If just a
symbol not found, then the issue is the Class changed.

Signed-off-by: Shawn Hurley <shawn@hurley.page>
  • Loading branch information
shawn-hurley committed Feb 14, 2025
1 parent 389283a commit 4977d3e
Show file tree
Hide file tree
Showing 8 changed files with 55 additions and 47 deletions.
4 changes: 4 additions & 0 deletions .trunk/configs/custom-words.txt
Original file line number Diff line number Diff line change
Expand Up @@ -147,3 +147,7 @@ upperbound
venv
webassets
webmvc
httpsnoop
felixge
cenkalti
tmpl
6 changes: 5 additions & 1 deletion kai/llm_interfacing/model_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,15 @@
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_ollama import ChatOllama
from langchain_openai import AzureChatOpenAI, ChatOpenAI
from opentelemetry import trace
from pydantic.v1.utils import deep_update

from kai.cache import Cache, CachePathResolver, SimplePathResolver
from kai.kai_config import KaiConfigModels
from kai.logging.logging import get_logger

LOG = get_logger(__name__)
tracer = trace.get_tracer("model_provider")


class ModelProvider:
Expand Down Expand Up @@ -211,6 +213,7 @@ def challenge(k: str) -> BaseMessage:
elif isinstance(self.llm, ChatDeepSeek):
challenge("max_tokens")

@tracer.start_as_current_span("invoke_llm")
def invoke(
self,
input: LanguageModelInput,
Expand All @@ -221,6 +224,8 @@ def invoke(
stop: Optional[list[str]] = None,
**kwargs: Any,
) -> BaseMessage:
span = trace.get_current_span()
span.set_attribute("model", self.model_id)
# Some fields can only be configured when the model is instantiated.
# This side-steps that by creating a new instance of the model with the
# configurable fields set, then invoking that new instance.
Expand Down Expand Up @@ -258,5 +263,4 @@ def invoke(
# only raise an exception when we are in demo mode
if self.demo_mode:
raise e

return response
5 changes: 3 additions & 2 deletions kai/reactive_codeplanner/agent/analyzer_fix/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from pygments.util import ClassNotFound

from kai.llm_interfacing.model_provider import ModelProvider
from kai.logging.logging import get_logger
from kai.logging.logging import TRACE, get_logger
from kai.reactive_codeplanner.agent.analyzer_fix.api import (
AnalyzerFixRequest,
AnalyzerFixResponse,
Expand Down Expand Up @@ -239,7 +239,8 @@ def guess_language(code: str, filename: Optional[str] = None) -> str:

if largest_rv_lexer:
# Remove all the extra information after the + sign
logger.debug(
logger.log(
TRACE,
"finding lexer %s, lexer aliases: %s",
lexer_found,
lexer_found.aliases,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from typing import Any, Optional

from jinja2 import Template
from langchain_core.messages import HumanMessage
from langchain_core.messages import HumanMessage, SystemMessage

from kai.llm_interfacing.model_provider import ModelProvider
from kai.logging.logging import get_logger
Expand All @@ -26,6 +26,22 @@ class FQDNDependencySelectorRequest(AgentRequest):
query: list[str]
times: int

def from_selector_request(
ask: "FQDNDependencySelectorRequest", new_query: list[str]
) -> "FQDNDependencySelectorRequest":
"""Will create a new selector request bumping the times by one"""
req = FQDNDependencySelectorRequest(
ask.file_path,
msg=ask.msg,
code="",
query=new_query,
times=ask.times + 1,
task=ask.task,
background=ask.background,
)
req.cache_path_resolver = ask.cache_path_resolver
return req


@dataclass
class FQDNDependencySelectorResult(AgentResult):
Expand All @@ -42,22 +58,22 @@ class __llm_response:
def __init__(self, model_provider: ModelProvider) -> None:
self._model_provider = model_provider

message = Template(
system_tmpl = Template(
"""
You are an excellent Java developer with expertise in dependency management.
Given an initial Maven compiler and a list of attempted searches, provide an updated dependency to use.
Do not use a dependency that has already been tried.
Only after all the dependencies have been tried, say only the word "TERMINATE", wait for the list of tried dependencies to have all the tries.
Think through the problem fully. Do not update the dependency if it has moved to newer versions; we want to find the version that matches, regardless of whether it is old or not.
Output in the format of:
Think through the problem. But only give one dependency to try at a time. You should only output "TERMINATE" or in the following format:
Reasoning
ArtifactId:
GroupId:
"""
)
message = Template(
"""
{{message}}
Searched dependencies:
Expand All @@ -80,7 +96,10 @@ def execute(self, ask: AgentRequest) -> FQDNDependencySelectorResult:
query = []
query.append(get_maven_query_from_code(ask.code))

msg = [HumanMessage(content=self.message.render(message=ask.msg, query=query))]
msg = [
SystemMessage(content=self.system_tmpl.render()),
HumanMessage(content=self.message.render(message=ask.msg, query=query)),
]
fix_gen_response = self._model_provider.invoke(msg, ask.cache_path_resolver)
llm_response = self.parse_llm_response(fix_gen_response.content)
# Really we need to re-call the agent
Expand All @@ -99,15 +118,7 @@ def execute(self, ask: AgentRequest) -> FQDNDependencySelectorResult:
## need to recursively call execute.
query.append(new_query)
return self.execute(
FQDNDependencySelectorRequest(
ask.file_path,
msg=ask.msg,
code="",
query=query,
times=ask.times + 1,
task=ask.task,
background=ask.background,
)
FQDNDependencySelectorRequest.from_selector_request(ask, query)
)
if isinstance(response, list):
response = None
Expand All @@ -125,6 +136,8 @@ def parse_llm_response(
reasoning_str = ""
artifact_id = ""
group_id = ""
if "TERMINATE" in content:
return None
for line in content.splitlines():
if not line:
continue
Expand Down
2 changes: 0 additions & 2 deletions kai/reactive_codeplanner/agent/maven_compiler_fix/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,6 @@ class MavenCompilerAgent(Agent):
"""{{ background }}
I will give you compiler errors and the offending line of code, and you will need to use the file to determine how to fix them. You should only use compiler errors to determine what to fix.
Make sure that the references to any changed types are kept.
You must reason through the required changes and rewrite the Java file to make it compile.
You will then provide an step-by-step explanation of the changes required so that someone could recreate it in a similar situation.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
DependencyResolutionError,
MavenCompilerError,
OtherError,
SymbolNotFoundError,
SyntaxError,
TypeMismatchError,
)
Expand Down Expand Up @@ -52,6 +53,7 @@ class MavenCompilerTaskRunner(TaskRunner):
AccessControlError,
OtherError,
DependencyResolutionError,
SymbolNotFoundError,
)

def __init__(self, agent: MavenCompilerAgent) -> None:
Expand Down
8 changes: 0 additions & 8 deletions kai/reactive_codeplanner/task_runner/dependency/api.py

This file was deleted.

26 changes: 10 additions & 16 deletions kai/reactive_codeplanner/task_runner/dependency/task_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,6 @@
from kai.reactive_codeplanner.task_runner.api import TaskRunner
from kai.reactive_codeplanner.task_runner.compiler.maven_validator import (
PackageDoesNotExistError,
SymbolNotFoundError,
)
from kai.reactive_codeplanner.task_runner.dependency.api import (
DependencyValidationError,
)
from kai.reactive_codeplanner.vfs.git_vfs import RepoContextManager

Expand All @@ -36,11 +32,7 @@ class DependencyTaskResponse:
class DependencyTaskRunner(TaskRunner):
"""TODO: Add Class Documentation"""

handled_type = (
DependencyValidationError,
SymbolNotFoundError,
PackageDoesNotExistError,
)
handled_type = (PackageDoesNotExistError,)

def __init__(self, agent: MavenDependencyAgent) -> None:
self._agent = agent
Expand All @@ -50,15 +42,11 @@ def can_handle_task(self, task: Task) -> bool:

@tracer.start_as_current_span("dependency_task_execute")
def execute_task(self, rcm: RepoContextManager, task: Task) -> TaskResult:
if not isinstance(task, self.handled_type):
if not isinstance(task, PackageDoesNotExistError):
logger.error("Unexpected task type %r", task)
return TaskResult(encountered_errors=[], modified_files=[], summary="")

msg = task.message
if isinstance(task, PackageDoesNotExistError) or isinstance(
task, SymbolNotFoundError
):
msg = f"Maven Compiler Error:\n{task.compiler_error_message()}"
msg = f"Maven Compiler Error:\n{task.compiler_error_message()}"

maven_dep_response = self._agent.execute(
MavenDependencyRequest(
Expand All @@ -78,7 +66,13 @@ def execute_task(self, rcm: RepoContextManager, task: Task) -> TaskResult:
"No final answer was given, we need to return with nothing modified. result: %r",
maven_dep_response,
)
return TaskResult(encountered_errors=[], modified_files=[], summary="")
return TaskResult(
encountered_errors=[
f"unable to fix compiler message: {msg} no dependency found"
],
modified_files=[],
summary="",
)

if not maven_dep_response.fqdn_response:
logger.info(
Expand Down

0 comments on commit 4977d3e

Please sign in to comment.