Update pydantic notebook to output results/traceback to file
Signed-off-by: Jason Montleon <jmontleo@redhat.com>
jmontleon committed Mar 1, 2024
1 parent 7fc02b0 commit a80b85f
Showing 152 changed files with 17,324 additions and 213 deletions.
4 changes: 4 additions & 0 deletions .trunk/trunk.yaml
@@ -49,6 +49,10 @@ lint:
- samples/analysis_reports/**
- samples/generated_output/**
- samples/sample_repos/**
# This file is from https://github.com/rh-aiservices-bu/llm-on-openshift
# It is included here only for convenience
- notebooks/jms_to_smallrye_reactive/caikit_tgis_langchain.py
- notebooks/pydantic/jms_to_smallrye_reactive/caikit_tgis_langchain.py
actions:
enabled:
- trunk-announce
@@ -22,25 +22,9 @@
},
{
"cell_type": "code",
"execution_count": 1,
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Obtaining file:///Users/jmatthews/git/jwmatthews/kai\n",
" Preparing metadata (setup.py) ... \u001b[?25ldone\n",
"\u001b[?25hInstalling collected packages: kai\n",
" Attempting uninstall: kai\n",
" Found existing installation: kai 0.0.1\n",
" Uninstalling kai-0.0.1:\n",
" Successfully uninstalled kai-0.0.1\n",
" Running setup.py develop for kai\n",
"Successfully installed kai-0.0.1\n"
]
}
],
"outputs": [],
"source": [
"# Install local kai package in the current Jupyter kernel\n",
"import sys\n",
@@ -50,7 +34,7 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
@@ -130,7 +114,7 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
@@ -158,6 +142,7 @@
" \n",
" After you have shared your step by step thinking, provide a full output of the updated file:\n",
"\n",
"### Input:\n",
" # Input information\n",
" ## Issue found from static code analysis of the Java EE code which needs to be fixed to migrate to Quarkus\n",
" Issue to fix: \"{analysis_message}\"\n",
@@ -207,18 +192,9 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Reading report from ./analysis_report/cmt/output.yaml\n",
"Reading report from ./analysis_report/helloworld-mdb-quarkus/output.yaml\n"
]
}
],
"outputs": [],
"source": [
"from kai.report import Report\n",
"from kai.scm import GitDiff\n",
@@ -289,22 +265,25 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Saved template to output/gpt-4-1106-preview/helloworldmdb/custom-ruleset/jms-to-reactive-quarkus-00010/few_shot/template.txt\n",
"Saved result to output/gpt-4-1106-preview/helloworldmdb/custom-ruleset/jms-to-reactive-quarkus-00010/few_shot/result.txt\n",
"Saved updated java file to output/gpt-4-1106-preview/helloworldmdb/custom-ruleset/jms-to-reactive-quarkus-00010/few_shot/updated_file.java\n",
"Saved original java file to output/gpt-4-1106-preview/helloworldmdb/custom-ruleset/jms-to-reactive-quarkus-00010/few_shot/original_file.java\n",
"Saved Quarkus version from Git to output/gpt-4-1106-preview/helloworldmdb/custom-ruleset/jms-to-reactive-quarkus-00010/few_shot/quarkus_version_from_git.java\n"
]
}
],
"outputs": [],
"source": [
"import caikit_tgis_langchain\n",
"\n",
"from langchain_openai import ChatOpenAI\n",
"\n",
"from langchain_core.messages import HumanMessage, SystemMessage\n",
"from genai import Client, Credentials\n",
"from genai.extensions.langchain.chat_llm import LangChainChatInterface\n",
"from genai.schema import (\n",
" DecodingMethod,\n",
" ModerationHAP,\n",
" ModerationParameters,\n",
" TextGenerationParameters,\n",
" TextGenerationReturnOptions,\n",
")\n",
"\n",
"# For first run we will only include the source file and analysis info (we are omitting solved example)\n",
"template_args = {\n",
" \"src_file_name\": src_file_name,\n",
@@ -317,10 +296,82 @@
"}\n",
"formatted_prompt = template_with_solved_example_and_analysis.format(**template_args)\n",
"\n",
"# model_name = \"gpt-4-1106-preview\"\n",
"model_name = \"gpt-3.5-turbo-16k\"\n",
"### Choose one of the options below to run against\n",
"### Uncomment the provider, one model, and the llm\n",
"\n",
"## OpenAI: Specify a model and make sure OPENAI_API_KEY is exported\n",
"provider = \"openai\"\n",
"# model = \"gpt-3.5-turbo\"\n",
"# model = \"gpt-3.5-turbo-16k\"\n",
"# model = \"gpt-4-1106-preview\"\n",
"# llm = ChatOpenAI(temperature=0.1,\n",
"# model_name=model,\n",
"# streaming=True)\n",
"\n",
"## Oobabooga: text-gen uses the OpenAI API so it works similarly.\n",
"## We load a model and specify temperature on the server so we don't need to specify them.\n",
"## We've had decent results with Mistral-7b and Codellama-34b-Instruct\n",
"## Just the URL for the server and make sure OPENAI_API_KEY is exported\n",
"# provider = \"text-gen\"\n",
"# model = \"CodeLlama-7b-Instruct-hf\"\n",
"# model = \"CodeLlama-13b-Instruct-hf\"\n",
"# model = \"CodeLlama-34b-Instruct-hf\"\n",
"# model = \"CodeLlama-70b-Instruct-hf\"\n",
"# model = \"starcoder\"\n",
"# model = \"starcoder2-3b\"\n",
"# model = \"starcoder2-7b\"\n",
"# model = \"starcoder2-15b\"\n",
"# model = \"WizardCoder-15B-V1.0\"\n",
"# model = \"Mistral-7B-v0.1\"\n",
"# model = \"Mixtral-8x7B-v0.1\"\n",
"# llm = ChatOpenAI(openai_api_base=\"https://text-gen-api-text-gen.apps.ai.migration.redhat.com/v1\",\n",
"# temperature=0.1,\n",
"# streaming=True)\n",
"\n",
"## OpenShift AI:\n",
"## We need to make sure caikit-nlp-client is installed.\n",
"## As well as caikit_tgis_langchain.py: https://github.com/rh-aiservices-bu/llm-on-openshift/blob/main/examples/notebooks/langchain/caikit_tgis_langchain.py\n",
"## Then set the model_id and server url\n",
"## But we are having issues with responses: https://github.com/opendatahub-io/caikit-nlp-client/issues/95\n",
"# provider = \"openshift-ai\"\n",
"# model = \"codellama-7b-hf\"\n",
"# llm = caikit_tgis_langchain.CaikitLLM(inference_server_url=\"https://codellama-7b-hf-predictor-kyma-workshop.apps.ai.migration.redhat.com\",\n",
"# model_id=\"CodeLlama-7b-hf\",\n",
"# streaming=True)\n",
"\n",
"## IBM Models:\n",
"## Ensure GENAI_KEY is exported. Change the model if desired.\n",
"# provider = \"ibm\"\n",
"# model = \"ibm/granite-13b-instruct-v1\"\n",
"# model = \"ibm/granite-13b-instruct-v2\"\n",
"# model = \"ibm/granite-20b-5lang-instruct-rc\"\n",
"# model = \"ibm/granite-20b-code-instruct-v1\"\n",
"# model = \"ibm/granite-20b-code-instruct-v1-gptq\"\n",
"# model = \"ibm-mistralai/mixtral-8x7b-instruct-v01-q\"\n",
"# model = \"codellama/codellama-34b-instruct\"\n",
"# model = \"codellama/codellama-70b-instruct\"\n",
"# model = \"kaist-ai/prometheus-13b-v1\"\n",
"# model = \"mistralai/mistral-7b-instruct-v0-2\"\n",
"# model = \"thebloke/mixtral-8x7b-v0-1-gptq\"\n",
"# llm = LangChainChatInterface(\n",
"# client=Client(credentials=Credentials.from_env()),\n",
"# model_id=model,\n",
"# parameters=TextGenerationParameters(\n",
"# decoding_method=DecodingMethod.SAMPLE,\n",
"# max_new_tokens=4096,\n",
"# min_new_tokens=10,\n",
"# temperature=0.1,\n",
"# top_k=50,\n",
"# top_p=1,\n",
"# return_options=TextGenerationReturnOptions(input_text=False, input_tokens=True),\n",
"# ),\n",
"# moderations=ModerationParameters(\n",
"# # Threshold is set to very low level to flag everything (testing purposes)\n",
"# # or set to True to enable HAP with default settings\n",
"# hap=ModerationHAP(input=True, output=False, threshold=0.01)\n",
"# ),\n",
"# )\n",
"\n",
"llm = ChatOpenAI(temperature=0.1, model_name=model_name)\n",
"prompt = PromptTemplate.from_template(template_with_solved_example_and_analysis)\n",
"chain = LLMChain(llm=llm, prompt=prompt)\n",
"result = chain.run(template_args)\n",
@@ -331,7 +382,8 @@
" src_file_name, \"quarkus\"\n",
")\n",
"\n",
"output_dir = f\"output/{model_name}/helloworldmdb/{ruleset_name}/{rule}/few_shot/\"\n",
"model = model.replace(\"/\", \"_\")\n",
"output_dir = f\"output/{provider}_{model}/helloworldmdb/{ruleset_name}/{rule}/few_shot/\"\n",
"write_output_to_disk(\n",
" output_dir,\n",
" formatted_prompt,\n",
@@ -358,7 +410,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.6"
"version": "3.12.2"
}
},
"nbformat": 4,
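For reference, a minimal sketch of the output-path naming introduced in the notebook above, using ruleset and rule values taken from the earlier sample output purely for illustration; model ids that contain slashes (for example the IBM "ibm/granite-13b-instruct-v1" id) are flattened so they do not create nested directories:

provider = "ibm"
model = "ibm/granite-13b-instruct-v1"  # the slash would otherwise become a subdirectory
ruleset_name = "custom-ruleset"        # illustrative values from the sample output
rule = "jms-to-reactive-quarkus-00010"

model = model.replace("/", "_")
output_dir = f"output/{provider}_{model}/helloworldmdb/{ruleset_name}/{rule}/few_shot/"
print(output_dir)
# output/ibm_ibm-granite-13b-instruct-v1 is flattened to a single path segment:
# output/ibm_ibm_granite-13b-instruct-v1/helloworldmdb/custom-ruleset/jms-to-reactive-quarkus-00010/few_shot/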
131 changes: 131 additions & 0 deletions notebooks/jms_to_smallrye_reactive/caikit_tgis_langchain.py
@@ -0,0 +1,131 @@
from typing import Any, Iterator, List, Mapping, Optional, Union
from warnings import warn

from caikit_nlp_client import GrpcClient, HttpClient
from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.llms.base import LLM
from langchain.schema.output import GenerationChunk


class CaikitLLM(LLM):
inference_server_url: str
model_id: str
certificate_chain: Optional[str] = None
streaming: bool
client: HttpClient

def __init__(
self,
inference_server_url: str,
model_id: str,
certificate_chain: Optional[str] = None,
streaming: bool = False,
client: Optional[GrpcClient] = HttpClient(""),
):
super().__init__(
inference_server_url=inference_server_url,
model_id=model_id,
certificate_chain=certificate_chain,
streaming=streaming,
client=client,
)

self.inference_server_url = inference_server_url
self.model_id = model_id

if certificate_chain:
with open(certificate_chain, "rb") as fh:
chain = fh.read()
else:
chain = None

if inference_server_url.startswith("http"):
client = HttpClient(inference_server_url)
else:
try:
host, port = inference_server_url.split(":")
if not all((host, port)):
raise ValueError
except ValueError:
raise ValueError(
"Invalid url provided, must be either "
'"host:port" or "http[s]://host:port/path"'
)

client = GrpcClient(host, port, ca_cert=chain)

self.client: Union[HttpClient, GrpcClient] = client

@property
def _llm_type(self) -> str:
return "caikit_tgis"

def _call(
self,
prompt: str,
preserve_input_text: bool = False,
max_new_tokens: int = 512,
min_new_tokens: int = 10,
device: str = "",
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> str:
if self.streaming:
return "".join(
self._stream(
prompt=prompt,
preserve_input_text=preserve_input_text,
max_new_tokens=max_new_tokens,
min_new_tokens=min_new_tokens,
device=device,
stop=stop,
run_manager=run_manager,
**kwargs,
)
)
if run_manager:
warn("run_manager is ignored for non-streaming use cases")

if device or stop:
raise NotImplementedError()

return self.client.generate_text(
self.model_id,
prompt,
preserve_input_text=preserve_input_text,
max_new_tokens=max_new_tokens,
min_new_tokens=min_new_tokens,
)

def _stream(
self,
prompt: str,
preserve_input_text: bool = False,
max_new_tokens: int = 512,
min_new_tokens: int = 10,
device: str = "",
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Iterator[GenerationChunk]:
if device or stop:
raise NotImplementedError

for token in self.client.generate_text_stream(
self.model_id,
prompt,
preserve_input_text=preserve_input_text,
max_new_tokens=max_new_tokens,
min_new_tokens=min_new_tokens,
):
chunk = GenerationChunk(text=token)
yield chunk.text

if run_manager:
run_manager.on_llm_new_token(chunk.text)

@property
def _identifying_params(self) -> Mapping[str, Any]:
"""Get the identifying parameters."""
return {"inference_server_url": self.inference_server_url}