From 5e078d707b0ee886ab7e048ea323462697677d79 Mon Sep 17 00:00:00 2001 From: JonahSussman Date: Wed, 21 Aug 2024 03:08:25 -0400 Subject: [PATCH] Various bugfixes when putting together the demo Signed-off-by: JonahSussman --- .../solution_handling/before_and_after.jinja | 4 +- .../solution_handling/diff_only.jinja | 2 +- .../solution_handling/llm_summary.jinja | 2 +- kai/models/util.py | 6 +- kai/routes/load_analysis_report.py | 5 +- kai/service/incident_store/incident_store.py | 18 ++++-- kai/service/incident_store/sql_types.py | 2 +- .../kai_application/kai_application.py | 7 ++- kai/service/llm_interfacing/model_provider.py | 55 ++++++++++++++++++- kai/service/solution_handling/consumption.py | 3 + kai/service/solution_handling/production.py | 14 +++-- 11 files changed, 96 insertions(+), 22 deletions(-) diff --git a/kai/data/templates/solution_handling/before_and_after.jinja b/kai/data/templates/solution_handling/before_and_after.jinja index 2b7b34bc..cd5f5911 100644 --- a/kai/data/templates/solution_handling/before_and_after.jinja +++ b/kai/data/templates/solution_handling/before_and_after.jinja @@ -1,9 +1,9 @@ -Solution before changes: +Solved example before changes: ``` {{ solution.original_code }} ``` -Solution after changes: +Solved example after changes: ``` {{ solution.updated_code }} ``` diff --git a/kai/data/templates/solution_handling/diff_only.jinja b/kai/data/templates/solution_handling/diff_only.jinja index 72b5ddce..d3d47b72 100644 --- a/kai/data/templates/solution_handling/diff_only.jinja +++ b/kai/data/templates/solution_handling/diff_only.jinja @@ -1,3 +1,3 @@ -Solution diff: +Solved example diff: ```diff {{ solution.file_diff }} diff --git a/kai/data/templates/solution_handling/llm_summary.jinja b/kai/data/templates/solution_handling/llm_summary.jinja index 170e7afd..1afcb8de 100644 --- a/kai/data/templates/solution_handling/llm_summary.jinja +++ b/kai/data/templates/solution_handling/llm_summary.jinja @@ -1,3 +1,3 @@ -Summary of changes for solution: +Summary of changes for solved example: {{ solution.llm_summary }} diff --git a/kai/models/util.py b/kai/models/util.py index 05475e7e..39b8cf5c 100644 --- a/kai/models/util.py +++ b/kai/models/util.py @@ -12,7 +12,11 @@ # These are known unique variables that can be included by incidents # They would prevent matches that we actually want, so we filter them # before adding to the database or searching -FILTERED_INCIDENT_VARS = ("file", "package") +FILTERED_INCIDENT_VARS = [ + "file", # Java, URI of the offending file + "package", # Java, shows the package + "name", # Java, shows the name of the method that caused the incident +] def remove_known_prefixes(path: str) -> str: diff --git a/kai/routes/load_analysis_report.py b/kai/routes/load_analysis_report.py index 62bc45b3..256cfab7 100644 --- a/kai/routes/load_analysis_report.py +++ b/kai/routes/load_analysis_report.py @@ -21,8 +21,9 @@ class PostLoadAnalysisReportApplication(BaseModel): class PostLoadAnalysisReportParams(BaseModel): - path_to_report: str application: PostLoadAnalysisReportApplication + report_data: dict | list[dict] + report_id: str @to_route("post", "/load_analysis_report") @@ -30,7 +31,7 @@ async def post_load_analysis_report(request: Request): params = PostLoadAnalysisReportParams.model_validate(await request.json()) application = Application(**params.application.model_dump()) - report = Report.load_report_from_file(params.path_to_report) + report = Report(params.report_data, params.report_id) count = request.app["kai_application"].incident_store.load_report( application, report diff --git a/kai/service/incident_store/incident_store.py b/kai/service/incident_store/incident_store.py index 4ce0feb0..6cde43d0 100644 --- a/kai/service/incident_store/incident_store.py +++ b/kai/service/incident_store/incident_store.py @@ -181,6 +181,8 @@ def __init__( self.solution_detector = solution_detector self.solution_producer = solution_producer + self.create_tables() # This is a no-op if the tables already exist + def load_report(self, app: Application, report: Report) -> tuple[int, int, int]: """ Load incidents from a report and given application object. Returns a @@ -273,6 +275,7 @@ def load_report(self, app: Application, report: Report) -> tuple[int, int, int]: session.commit() for incident in violation_obj.incidents: + filtered_vars = filter_incident_vars(incident.variables) report_incidents.append( SQLIncident( violation_name=violation.violation_name, @@ -281,7 +284,7 @@ def load_report(self, app: Application, report: Report) -> tuple[int, int, int]: incident_uri=incident.uri, incident_snip=incident.code_snip, incident_line=incident.line_number, - incident_variables=deep_sort(incident.variables), + incident_variables=deep_sort(filtered_vars), incident_message=incident.message, ) ) @@ -401,9 +404,7 @@ def find_solutions( ) result: list[Solution] = [] - for incident in session.execute( - select_incidents_with_solutions_stmt - ).scalars(): + for incident in session.scalars(select_incidents_with_solutions_stmt).all(): select_accepted_solution_stmt = select(SQLAcceptedSolution).where( SQLAcceptedSolution.solution_id == incident.solution_id ) @@ -415,9 +416,18 @@ def find_solutions( processed_solution = self.solution_producer.post_process_one( incident, accepted_solution.solution ) + + # TODO: This first line doesn't work for some reason. The second + # line is a hack to get around it. accepted_solution.solution = processed_solution + session.query(SQLAcceptedSolution).filter( + SQLAcceptedSolution.solution_id == incident.solution_id + ).update({"solution": processed_solution}) + result.append(processed_solution) + session.commit() + session.commit() return result diff --git a/kai/service/incident_store/sql_types.py b/kai/service/incident_store/sql_types.py index ed63d84c..bb1bf7ac 100644 --- a/kai/service/incident_store/sql_types.py +++ b/kai/service/incident_store/sql_types.py @@ -23,7 +23,7 @@ class SQLSolutionType(TypeDecorator): impl = VARCHAR - cache_ok = True + cache_ok = False def process_bind_param(self, value: Optional[Solution], dialect: Dialect): # Into the db diff --git a/kai/service/kai_application/kai_application.py b/kai/service/kai_application/kai_application.py index 2815f0d4..eee9e2ef 100644 --- a/kai/service/kai_application/kai_application.py +++ b/kai/service/kai_application/kai_application.py @@ -148,9 +148,10 @@ def get_incident_solutions_for_file( ) if len(solutions) != 0: - pb_incident["solution_str"] = self.solution_consumer( - solutions[0] - ) + solution_str = self.solution_consumer(solutions[0]) + + if len(solution_str) != 0: + pb_incident["solution_str"] = solution_str pb_vars = { "src_file_name": file_name, diff --git a/kai/service/llm_interfacing/model_provider.py b/kai/service/llm_interfacing/model_provider.py index be69fc07..5774c407 100644 --- a/kai/service/llm_interfacing/model_provider.py +++ b/kai/service/llm_interfacing/model_provider.py @@ -1,4 +1,5 @@ import os +from typing import Any from genai import Client, Credentials from genai.extensions.langchain.chat_llm import LangChainChatInterface @@ -100,7 +101,22 @@ def __init__(self, config: KaiConfigModels): model_class = FakeListChatModel defaults = { - "responses": ["Default LLM response."], + "responses": [ + "## Reasoning\n" + "\n" + "Default reasoning.\n" + "\n" + "## Updated File\n" + "\n" + "```\n" + "Default updated file.\n" + "```\n" + "\n" + "## Additional Information\n" + "\n" + "Default additional information.\n" + "\n" + ], "sleep": None, } @@ -137,3 +153,40 @@ def __init__(self, config: KaiConfigModels): ] else: self.llama_header = config.llama_header + + +class Wrapper: + """ + Wrapper class to intercept and log all attribute access and method calls on + an object. + """ + + def __init__(self, obj: Any): + self.obj: Any = obj + self.callable_results: list = [] + + def __getattr__(self, attr: Any): + print(f"Getting {type(self.obj).__name__}.{attr}") + + result = getattr(self.obj, attr) + if callable(result): + return self.CallableWrapper(self, result) + + return result + + class CallableWrapper: + def __init__(self, parent: "Wrapper", callable: Any): + self.parent = parent + self.callable = callable + + def __call__(self, *args, **kwargs): + print(f"Calling {type(self.parent.obj).__name__}.{self.callable.__name__}") + + for i, arg in enumerate(args): + print(f" arg {i}: {arg}") + for key, value in kwargs.items(): + print(f" {key}: {value}") + + result = self.callable(*args, **kwargs) + self.parent.callable_results.append(result) + return result diff --git a/kai/service/solution_handling/consumption.py b/kai/service/solution_handling/consumption.py index 48e63482..4569767d 100644 --- a/kai/service/solution_handling/consumption.py +++ b/kai/service/solution_handling/consumption.py @@ -42,6 +42,9 @@ def solution_consumer_before_and_after(solution: Solution) -> str: def solution_consumer_llm_summary(solution: Solution) -> str: + if solution.llm_summary is None: + return "" + return ( __create_jinja_env().get_template("llm_summary.jinja").render(solution=solution) ) diff --git a/kai/service/solution_handling/production.py b/kai/service/solution_handling/production.py index b2453cdf..34bee2f2 100644 --- a/kai/service/solution_handling/production.py +++ b/kai/service/solution_handling/production.py @@ -65,9 +65,12 @@ class SolutionProducerTextOnly(SolutionProducer): def produce_one( self, incident: SQLIncident, repo: Repo, old_commit: str, new_commit: str ) -> Solution: + local_file_path = remove_known_prefixes( + unquote(urlparse(incident.incident_uri).path) + ) file_path = os.path.join( repo.working_tree_dir, - remove_known_prefixes(unquote(urlparse(incident.incident_uri).path)), + local_file_path, ) # NOTE: `repo_diff` functionality is not implemented @@ -77,7 +80,7 @@ def produce_one( # probably a better way to handle this. try: original_code = ( - repo.git.show(f"{new_commit}:{file_path}") + repo.git.show(f"{old_commit}:{local_file_path}") .encode("utf-8") .decode("utf-8") ) @@ -86,7 +89,7 @@ def produce_one( try: updated_code = ( - repo.git.show(f"{new_commit}:{file_path}") + repo.git.show(f"{new_commit}:{local_file_path}") .encode("utf-8") .decode("utf-8") ) @@ -113,13 +116,12 @@ def post_process_one(self, incident: SQLIncident, solution: Solution) -> Solutio class SolutionProducerLLMLazy(SolutionProducer): def __init__(self, model_provider: ModelProvider): self.model_provider = model_provider + self.text_only = SolutionProducerTextOnly() def produce_one( self, incident: SQLIncident, repo: Repo, old_commit: str, new_commit: str ) -> Solution: - solution = SolutionProducerTextOnly().produce_one( - incident, repo, old_commit, new_commit - ) + solution = self.text_only.produce_one(incident, repo, old_commit, new_commit) solution.llm_summary_generated = False