Skip to content

Commit

Permalink
Various bugfixes when putting together the demo
Browse files Browse the repository at this point in the history
Signed-off-by: JonahSussman <sussmanjonah@gmail.com>
  • Loading branch information
JonahSussman committed Aug 21, 2024
1 parent 554a3e6 commit 5e078d7
Show file tree
Hide file tree
Showing 11 changed files with 96 additions and 22 deletions.
4 changes: 2 additions & 2 deletions kai/data/templates/solution_handling/before_and_after.jinja
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
Solution before changes:
Solved example before changes:
```
{{ solution.original_code }}
```

Solution after changes:
Solved example after changes:
```
{{ solution.updated_code }}
```
2 changes: 1 addition & 1 deletion kai/data/templates/solution_handling/diff_only.jinja
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
Solution diff:
Solved example diff:
```diff
{{ solution.file_diff }}
2 changes: 1 addition & 1 deletion kai/data/templates/solution_handling/llm_summary.jinja
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
Summary of changes for solution:
Summary of changes for solved example:

{{ solution.llm_summary }}
6 changes: 5 additions & 1 deletion kai/models/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,11 @@
# These are known unique variables that can be included by incidents
# They would prevent matches that we actually want, so we filter them
# before adding to the database or searching
FILTERED_INCIDENT_VARS = ("file", "package")
FILTERED_INCIDENT_VARS = [
"file", # Java, URI of the offending file
"package", # Java, shows the package
"name", # Java, shows the name of the method that caused the incident
]


def remove_known_prefixes(path: str) -> str:
Expand Down
5 changes: 3 additions & 2 deletions kai/routes/load_analysis_report.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,16 +21,17 @@ class PostLoadAnalysisReportApplication(BaseModel):


class PostLoadAnalysisReportParams(BaseModel):
path_to_report: str
application: PostLoadAnalysisReportApplication
report_data: dict | list[dict]
report_id: str


@to_route("post", "/load_analysis_report")
async def post_load_analysis_report(request: Request):
params = PostLoadAnalysisReportParams.model_validate(await request.json())

application = Application(**params.application.model_dump())
report = Report.load_report_from_file(params.path_to_report)
report = Report(params.report_data, params.report_id)

count = request.app["kai_application"].incident_store.load_report(
application, report
Expand Down
18 changes: 14 additions & 4 deletions kai/service/incident_store/incident_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,8 @@ def __init__(
self.solution_detector = solution_detector
self.solution_producer = solution_producer

self.create_tables() # This is a no-op if the tables already exist

def load_report(self, app: Application, report: Report) -> tuple[int, int, int]:
"""
Load incidents from a report and given application object. Returns a
Expand Down Expand Up @@ -273,6 +275,7 @@ def load_report(self, app: Application, report: Report) -> tuple[int, int, int]:
session.commit()

for incident in violation_obj.incidents:
filtered_vars = filter_incident_vars(incident.variables)
report_incidents.append(
SQLIncident(
violation_name=violation.violation_name,
Expand All @@ -281,7 +284,7 @@ def load_report(self, app: Application, report: Report) -> tuple[int, int, int]:
incident_uri=incident.uri,
incident_snip=incident.code_snip,
incident_line=incident.line_number,
incident_variables=deep_sort(incident.variables),
incident_variables=deep_sort(filtered_vars),
incident_message=incident.message,
)
)
Expand Down Expand Up @@ -401,9 +404,7 @@ def find_solutions(
)

result: list[Solution] = []
for incident in session.execute(
select_incidents_with_solutions_stmt
).scalars():
for incident in session.scalars(select_incidents_with_solutions_stmt).all():
select_accepted_solution_stmt = select(SQLAcceptedSolution).where(
SQLAcceptedSolution.solution_id == incident.solution_id
)
Expand All @@ -415,9 +416,18 @@ def find_solutions(
processed_solution = self.solution_producer.post_process_one(
incident, accepted_solution.solution
)

# TODO: This first line doesn't work for some reason. The second
# line is a hack to get around it.
accepted_solution.solution = processed_solution
session.query(SQLAcceptedSolution).filter(
SQLAcceptedSolution.solution_id == incident.solution_id
).update({"solution": processed_solution})

result.append(processed_solution)

session.commit()

session.commit()

return result
Expand Down
2 changes: 1 addition & 1 deletion kai/service/incident_store/sql_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@

class SQLSolutionType(TypeDecorator):
impl = VARCHAR
cache_ok = True
cache_ok = False

def process_bind_param(self, value: Optional[Solution], dialect: Dialect):
# Into the db
Expand Down
7 changes: 4 additions & 3 deletions kai/service/kai_application/kai_application.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,9 +148,10 @@ def get_incident_solutions_for_file(
)

if len(solutions) != 0:
pb_incident["solution_str"] = self.solution_consumer(
solutions[0]
)
solution_str = self.solution_consumer(solutions[0])

if len(solution_str) != 0:
pb_incident["solution_str"] = solution_str

pb_vars = {
"src_file_name": file_name,
Expand Down
55 changes: 54 additions & 1 deletion kai/service/llm_interfacing/model_provider.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os
from typing import Any

from genai import Client, Credentials
from genai.extensions.langchain.chat_llm import LangChainChatInterface
Expand Down Expand Up @@ -100,7 +101,22 @@ def __init__(self, config: KaiConfigModels):
model_class = FakeListChatModel

defaults = {
"responses": ["Default LLM response."],
"responses": [
"## Reasoning\n"
"\n"
"Default reasoning.\n"
"\n"
"## Updated File\n"
"\n"
"```\n"
"Default updated file.\n"
"```\n"
"\n"
"## Additional Information\n"
"\n"
"Default additional information.\n"
"\n"
],
"sleep": None,
}

Expand Down Expand Up @@ -137,3 +153,40 @@ def __init__(self, config: KaiConfigModels):
]
else:
self.llama_header = config.llama_header


class Wrapper:
"""
Wrapper class to intercept and log all attribute access and method calls on
an object.
"""

def __init__(self, obj: Any):
self.obj: Any = obj
self.callable_results: list = []

def __getattr__(self, attr: Any):
print(f"Getting {type(self.obj).__name__}.{attr}")

result = getattr(self.obj, attr)
if callable(result):
return self.CallableWrapper(self, result)

return result

class CallableWrapper:
def __init__(self, parent: "Wrapper", callable: Any):
self.parent = parent
self.callable = callable

def __call__(self, *args, **kwargs):
print(f"Calling {type(self.parent.obj).__name__}.{self.callable.__name__}")

for i, arg in enumerate(args):
print(f" arg {i}: {arg}")
for key, value in kwargs.items():
print(f" {key}: {value}")

result = self.callable(*args, **kwargs)
self.parent.callable_results.append(result)
return result
3 changes: 3 additions & 0 deletions kai/service/solution_handling/consumption.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,9 @@ def solution_consumer_before_and_after(solution: Solution) -> str:


def solution_consumer_llm_summary(solution: Solution) -> str:
if solution.llm_summary is None:
return ""

return (
__create_jinja_env().get_template("llm_summary.jinja").render(solution=solution)
)
Expand Down
14 changes: 8 additions & 6 deletions kai/service/solution_handling/production.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,9 +65,12 @@ class SolutionProducerTextOnly(SolutionProducer):
def produce_one(
self, incident: SQLIncident, repo: Repo, old_commit: str, new_commit: str
) -> Solution:
local_file_path = remove_known_prefixes(
unquote(urlparse(incident.incident_uri).path)
)
file_path = os.path.join(
repo.working_tree_dir,
remove_known_prefixes(unquote(urlparse(incident.incident_uri).path)),
local_file_path,
)

# NOTE: `repo_diff` functionality is not implemented
Expand All @@ -77,7 +80,7 @@ def produce_one(
# probably a better way to handle this.
try:
original_code = (
repo.git.show(f"{new_commit}:{file_path}")
repo.git.show(f"{old_commit}:{local_file_path}")
.encode("utf-8")
.decode("utf-8")
)
Expand All @@ -86,7 +89,7 @@ def produce_one(

try:
updated_code = (
repo.git.show(f"{new_commit}:{file_path}")
repo.git.show(f"{new_commit}:{local_file_path}")
.encode("utf-8")
.decode("utf-8")
)
Expand All @@ -113,13 +116,12 @@ def post_process_one(self, incident: SQLIncident, solution: Solution) -> Solutio
class SolutionProducerLLMLazy(SolutionProducer):
def __init__(self, model_provider: ModelProvider):
self.model_provider = model_provider
self.text_only = SolutionProducerTextOnly()

def produce_one(
self, incident: SQLIncident, repo: Repo, old_commit: str, new_commit: str
) -> Solution:
solution = SolutionProducerTextOnly().produce_one(
incident, repo, old_commit, new_commit
)
solution = self.text_only.produce_one(incident, repo, old_commit, new_commit)

solution.llm_summary_generated = False

Expand Down

0 comments on commit 5e078d7

Please sign in to comment.