Skip to content

Commit

Permalink
Add an example using pydantic to parse structured data
Browse files Browse the repository at this point in the history
Signed-off-by: Jason Montleon <jmontleo@redhat.com>
  • Loading branch information
jmontleon committed Feb 27, 2024
1 parent cb85418 commit 9220731
Show file tree
Hide file tree
Showing 2 changed files with 378 additions and 4 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,371 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Few Shot: JMS -> SmallRye Reactive - jms-to-reactive-quarkus-00010\n",
"\n",
"> Work with a LLM to generate a fix for the rule: jms-to-reactive-quarkus-00010\n",
"\n",
"We will include the solved example in this example to see if we get better quality results\n",
"\n",
"##### Sample Applications Used\n",
"* 2 sample apps from [JBoss EAP Quickstarts](https://github.com/jboss-developer/jboss-eap-quickstarts/tree/7.4.x) were chosen: helloworld-mdb & cmt\n",
" * [helloworld-mdb](https://github.com/savitharaghunathan/helloworld-mdb)\n",
" * [cmt](https://github.com/konveyor-ecosystem/cmt)\n",
"\n",
"##### Using Custom Rules for JMS to SmallRye Reactive\n",
"* Rules were developed by Juanma [@jmle](https://github.com/jmle)\n",
" * Rules originally from: https://github.com/jmle/rulesets/blob/jms-rule/default/generated/quarkus/05-jms-to-reactive-quarkus.windup.yaml\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Install local kai package in the current Jupyter kernel\n",
"import sys\n",
"\n",
"!{sys.executable} -m pip install -e ../../"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Bring in the common deps\n",
"import os\n",
"import pprint\n",
"\n",
"from langchain import PromptTemplate\n",
"from langchain.chat_models import ChatOpenAI\n",
"from langchain.chains import LLMChain\n",
"\n",
"pp = pprint.PrettyPrinter(indent=2)\n",
"\n",
"\n",
"def get_java_in_result(text: str):\n",
" _, after = text.split(\"```java\")\n",
" return after.split(\"```\")[0]\n",
"\n",
"\n",
"def write_output_to_disk(\n",
" out_dir,\n",
" formatted_template,\n",
" example_javaee_file_contents,\n",
" example_quarkus_file_contents,\n",
" result,\n",
"):\n",
" try:\n",
" os.makedirs(out_dir, exist_ok=True)\n",
" except OSError as error:\n",
" print(f\"Error creating directory {out_dir}: {error}\")\n",
" raise error\n",
"\n",
" # Save the template to disk\n",
" template_path = os.path.join(out_dir, \"template.txt\")\n",
" with open(template_path, \"w\") as f:\n",
" f.truncate()\n",
" f.write(formatted_template)\n",
" print(f\"Saved template to {template_path}\")\n",
"\n",
" # Save the result\n",
" result_path = os.path.join(out_dir, \"result.txt\")\n",
" with open(result_path, \"w\") as f:\n",
" f.truncate()\n",
" f.write(result)\n",
" print(f\"Saved result to {result_path}\")\n",
"\n",
" # Extract the updated java code and save it\n",
" updated_file_contents = get_java_in_result(result)\n",
" updated_file_path = os.path.join(out_dir, \"updated_file.java\")\n",
" with open(updated_file_path, \"w\") as f:\n",
" f.truncate()\n",
" f.write(updated_file_contents)\n",
" print(f\"Saved updated java file to {updated_file_path}\")\n",
"\n",
" # Save the original source code\n",
" original_file_path = os.path.join(out_dir, \"original_file.java\")\n",
" with open(original_file_path, \"w\") as f:\n",
" f.truncate()\n",
" f.write(example_javaee_file_contents)\n",
" print(f\"Saved original java file to {original_file_path}\")\n",
"\n",
" # Save the Quarkus version we already in Git to use as a comparison\n",
" known_quarkus_file_path = os.path.join(out_dir, \"quarkus_version_from_git.java\")\n",
" with open(known_quarkus_file_path, \"w\") as f:\n",
" f.truncate()\n",
" f.write(example_quarkus_file_contents)\n",
" print(f\"Saved Quarkus version from Git to {known_quarkus_file_path}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Create a Prompt\n",
"## Step #1: Create a Prompt Template\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"template_with_solved_example_and_analysis = \"\"\"\n",
" # Java EE to Quarkus Migration\n",
" You are an AI Assistant trained on migrating enterprise JavaEE code to Quarkus.\n",
" I will give you an example of a JavaEE file and you will give me the Quarkus equivalent.\n",
"\n",
" To help you update this file to Quarkus I will provide you with static source code analysis information\n",
" highlighting an issue which needs to be addressed, I will also provide you with an example of how a similar\n",
" issue was solved in the past via a solved example. You can refer to the solved example for a pattern of\n",
" how to update the input Java EE file to Quarkus.\n",
"\n",
" Be sure to pay attention to the issue found from static analysis and treat it as the primary issue you must \n",
" address or explain why you are unable to.\n",
"\n",
" Approach this code migration from Java EE to Quarkus as if you were an experienced enterprise Java EE developer.\n",
" Before attempting to migrate the code to Quarkus, explain each step of your reasoning through what changes \n",
" are required and why. \n",
"\n",
" Pay attention to changes you make and impacts to external dependencies in the pom.xml as well as changes \n",
" to imports we need to consider.\n",
"\n",
" As you make changes that impact the pom.xml or imports, be sure you explain what needs to be updated.\n",
" \n",
" {format_instructions}\n",
" \n",
"### Input:\n",
" ## Issue found from static code analysis of the Java EE code which needs to be fixed to migrate to Quarkus\n",
" Issue to fix: \"{analysis_message}\"\n",
"\n",
" ## Solved Example Filename\n",
" Filename: \"{solved_example_file_name}\"\n",
"\n",
" ## Solved Example Git Diff \n",
" This diff of the solved example shows what changes we made in past to address a similar problem.\n",
" Please consider this heavily in your response.\n",
" ```diff\n",
" {solved_example_diff}\n",
" ```\n",
"\n",
" ## Input file name\n",
" Filename: \"{src_file_name}\"\n",
"\n",
" ## Input Line number of the issue first appearing in the Java EE code source code example below\n",
" Line number: {analysis_lineNumber}\n",
" \n",
" ## Input source code file contents for \"{src_file_name}\"\n",
" ```java \n",
" {src_file_contents}\n",
" ```\n",
" \"\"\""
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Step 2: Gather the data we want to inject into the prompt"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from kai.report import Report\n",
"from kai.scm import GitDiff\n",
"\n",
"# Static code analysis was run prior and committed to this repo\n",
"path_cmt_analysis = \"./analysis_report/cmt/output.yaml\"\n",
"path_helloworld_analysis = \"./analysis_report/helloworld-mdb-quarkus/output.yaml\"\n",
"\n",
"cmt_report = Report(path_cmt_analysis)\n",
"helloworld_report = Report(path_helloworld_analysis)\n",
"\n",
"# We are limiting our experiment to a single rule for this run\n",
"ruleset_name = \"custom-ruleset\"\n",
"rule = \"jms-to-reactive-quarkus-00010\"\n",
"cmt_rule_data = cmt_report.report[ruleset_name][\"violations\"][rule]\n",
"helloworld_rule_data = helloworld_report.report[ruleset_name][\"violations\"][rule]\n",
"\n",
"# We are expecting to find 1 impacted file in CMT and 2 impacted files in HelloWorld\n",
"assert len(cmt_rule_data[\"incidents\"]), 1\n",
"assert len(helloworld_rule_data[\"incidents\"]), 2\n",
"\n",
"# Setup access to look at the Git repo source code for each sample app\n",
"cmt_src_path = \"../../samples/sample_repos/cmt\"\n",
"cmt_diff = GitDiff(cmt_src_path)\n",
"\n",
"# Ensure we checked out the 'quarkus' branch of the CMT repo\n",
"cmt_diff.checkout_branch(\"quarkus\")\n",
"\n",
"helloworld_src_path = \"../../samples/sample_repos/helloworld-mdb-quarkus\"\n",
"helloworld_diff = GitDiff(helloworld_src_path)\n",
"\n",
"# We want to be sure the the 'quarkus' branch of both repos has been checked out\n",
"# so it's available to consult\n",
"cmt_diff.checkout_branch(\"quarkus\")\n",
"helloworld_diff.checkout_branch(\"quarkus\")\n",
"# Resetting to 'main' for HelloWorld as that represents th initial state of the repo\n",
"helloworld_diff.checkout_branch(\"main\")\n",
"\n",
"# Now we can extract the info we need\n",
"## We will assume:\n",
"## . HelloWorld will be our source application to migrate, so we will use it's 'main' branch which maps to Java EE\n",
"## . CMT will serve as our solved example, so we will consult it's 'quarkus' branch\n",
"\n",
"hw_entry = helloworld_rule_data[\"incidents\"][0]\n",
"src_file_name = helloworld_report.get_cleaned_file_path(hw_entry[\"uri\"])\n",
"src_file_contents = helloworld_diff.get_file_contents(src_file_name, \"main\")\n",
"analysis_message = hw_entry[\"message\"]\n",
"analysis_lineNumber = hw_entry[\"lineNumber\"]\n",
"\n",
"cmt_entry = cmt_rule_data[\"incidents\"][0]\n",
"solved_example_file_name = cmt_report.get_cleaned_file_path(cmt_entry[\"uri\"])\n",
"# solved_file_contents = cmt_diff.get_file_contents(solved_example_file_name, \"quarkus\")\n",
"\n",
"start_commit_id = cmt_diff.get_commit_from_branch(\"main\").hexsha\n",
"end_commit_id = cmt_diff.get_commit_from_branch(\"quarkus\").hexsha\n",
"\n",
"solved_example_diff = cmt_diff.get_patch_for_file(\n",
" start_commit_id, end_commit_id, solved_example_file_name\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Step 3: Run the prompt through the LLM and write the output to disk"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from langchain_openai import ChatOpenAI\n",
"from langchain.output_parsers import PydanticOutputParser\n",
"from langchain_core.pydantic_v1 import BaseModel, Field, validator\n",
"from typing import List\n",
"\n",
"\n",
"class UpdatedFile(BaseModel):\n",
" file_name: str = Field(description=\"File name of udpated file\")\n",
" file_contents: str = Field(description=\"Contents of the updated file\")\n",
"\n",
"\n",
"class Response(BaseModel):\n",
" reasoning: List[str] = Field(description=\"Process Explanation\")\n",
" updated_files: List[UpdatedFile] = Field(\n",
" description=\"List containing updated files\"\n",
" )\n",
"\n",
"\n",
"response_parser = PydanticOutputParser(pydantic_object=Response)\n",
"\n",
"# For first run we will only include the source file and analysis info (we are omitting solved example)\n",
"template_args = {\n",
" \"src_file_name\": src_file_name,\n",
" \"src_file_contents\": src_file_contents,\n",
" \"analysis_message\": analysis_message,\n",
" \"analysis_lineNumber\": analysis_lineNumber,\n",
" \"solved_example_file_name\": solved_example_file_name,\n",
" \"solved_example_diff\": solved_example_diff,\n",
" \"format_instructions\": response_parser.get_format_instructions(),\n",
"}\n",
"\n",
"formatted_prompt = template_with_solved_example_and_analysis.format(**template_args)\n",
"\n",
"### Choose one of the options below to run against\n",
"\n",
"## OpenAI: Specify a model and make sure OPENAI_API_KEY is exported\n",
"# model_name = \"gpt-3.5-turbo\"\n",
"# model_name = \"gpt-4-1106-preview\"\n",
"# llm = ChatOpenAI(temperature=0.1,\n",
"# model_name=model_name,\n",
"# streaming=True)\n",
"\n",
"## Oobabooga: text-gen uses the OpenAI API so it works similarly.\n",
"## We load a model and specify temperature on the server so we don't need to specify them.\n",
"## We've had decent results with Mistral-7b and Codellama-34b-Instruct\n",
"## Just the URL for the server and make sure OPENAI_API_KEY is exported\n",
"# llm = ChatOpenAI(openai_api_base=\"https://text-gen-api-text-gen.apps.ai.migration.redhat.com/v1\",\n",
"# streaming=True)\n",
"\n",
"## OpenShift AI:\n",
"## We need to make sure caikit-nlp-client is installed.\n",
"## As well as caikit_tgis_langchain.py: https://github.com/rh-aiservices-bu/llm-on-openshift/blob/main/examples/notebooks/langchain/caikit_tgis_langchain.py\n",
"## Then set the model_id and server url\n",
"## But we are having issues with responses: https://github.com/opendatahub-io/caikit-nlp-client/issues/95\n",
"# llm = caikit_tgis_langchain.CaikitLLM(inference_server_url=\"https://codellama-7b-hf-predictor-kyma-workshop.apps.ai.migration.redhat.com\",\n",
"# model_id=\"CodeLlama-7b-hf\",\n",
"# streaming=True)\n",
"\n",
"## IBM Models:\n",
"# TODO\n",
"\n",
"prompt = PromptTemplate.from_template(template_with_solved_example_and_analysis)\n",
"chain = LLMChain(llm=llm, prompt=prompt, output_parser=response_parser)\n",
"result = chain.invoke(template_args)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Step 4: Print results"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"print(\"### Reasoning:\\n\")\n",
"for i in range(len(result[\"text\"].reasoning)):\n",
" print(result[\"text\"].reasoning[i])\n",
"\n",
"print(\"\\n\")\n",
"\n",
"for i in range(len(result[\"text\"].updated_files)):\n",
" print(\"### Updated file \" + str((i + 1)) + \"\\n\")\n",
" print(result[\"text\"].updated_files[i].file_name + \":\")\n",
" print(result[\"text\"].updated_files[i].file_contents)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "python3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.2"
}
},
"nbformat": 4,
"nbformat_minor": 4
}
Loading

0 comments on commit 9220731

Please sign in to comment.