Update pydantic notebook to output results/traceback to file
Signed-off-by: Jason Montleon <jmontleo@redhat.com>
jmontleon committed Mar 1, 2024
1 parent 7fc02b0 commit 86ab59e
Showing 91 changed files with 9,939 additions and 22 deletions.
3 changes: 3 additions & 0 deletions .trunk/trunk.yaml
@@ -49,6 +49,9 @@ lint:
- samples/analysis_reports/**
- samples/generated_output/**
- samples/sample_repos/**
# This file is from https://github.com/rh-aiservices-bu/llm-on-openshift
# It is included here only for convenience
- notebooks/pydantic/jms_to_smallrye_reactive/caikit_tgis_langchain.py
actions:
enabled:
- trunk-announce
@@ -59,7 +59,6 @@
" formatted_template,\n",
" example_javaee_file_contents,\n",
" example_quarkus_file_contents,\n",
" result,\n",
"):\n",
" try:\n",
" os.makedirs(out_dir, exist_ok=True)\n",
@@ -74,21 +73,6 @@
" f.write(formatted_template)\n",
" print(f\"Saved template to {template_path}\")\n",
"\n",
" # Save the result\n",
" result_path = os.path.join(out_dir, \"result.txt\")\n",
" with open(result_path, \"w\") as f:\n",
" f.truncate()\n",
" f.write(result)\n",
" print(f\"Saved result to {result_path}\")\n",
"\n",
" # Extract the updated java code and save it\n",
" updated_file_contents = get_java_in_result(result)\n",
" updated_file_path = os.path.join(out_dir, \"updated_file.java\")\n",
" with open(updated_file_path, \"w\") as f:\n",
" f.truncate()\n",
" f.write(updated_file_contents)\n",
" print(f\"Saved updated java file to {updated_file_path}\")\n",
"\n",
" # Save the original source code\n",
" original_file_path = os.path.join(out_dir, \"original_file.java\")\n",
" with open(original_file_path, \"w\") as f:\n",
@@ -255,6 +239,9 @@
"metadata": {},
"outputs": [],
"source": [
"import caikit_tgis_langchain\n",
"import traceback\n",
"\n",
"from langchain_openai import ChatOpenAI\n",
"from langchain.output_parsers import PydanticOutputParser\n",
"from langchain_core.pydantic_v1 import BaseModel, Field, validator\n",
@@ -299,36 +286,70 @@
"\n",
"formatted_prompt = template_with_solved_example_and_analysis.format(**template_args)\n",
"\n",
"# For comparison, we will look up what the source file looks like from Quarkus branch\n",
"# this serves as a way of comparing to what the 'answer' is that was done manually by a knowledgeable developer\n",
"src_file_from_quarkus_branch = helloworld_diff.get_file_contents(\n",
" src_file_name, \"quarkus\"\n",
")\n",
"\n",
"### Choose one of the options below to run against\n",
"\n",
"## OpenAI: Specify a model and make sure OPENAI_API_KEY is exported\n",
"# model_name = \"gpt-3.5-turbo\"\n",
"# model_name = \"gpt-4-1106-preview\"\n",
"# provider = \"openai\"\n",
"# model = \"gpt-3.5-turbo\"\n",
"# model = \"gpt-4-1106-preview\"\n",
"# llm = ChatOpenAI(temperature=0.1,\n",
"# model_name=model_name,\n",
"# model_name=model,\n",
"# streaming=True)\n",
"\n",
"## Oobabooga: text-gen uses the OpenAI API so it works similarly.\n",
"## We load a model and specify temperature on the server so we don't need to specify them.\n",
"## We've had decent results with Mistral-7b and Codellama-34b-Instruct\n",
"## Just the URL for the server and make sure OPENAI_API_KEY is exported\n",
"# provider = \"text-gen\"\n",
"# model = \"CodeLlama-7b-Instruct-hf\"\n",
"# model = \"CodeLlama-13b-Instruct-hf\"\n",
"# model = \"CodeLlama-34b-Instruct-hf\"\n",
"# model = \"CodeLlama-70b-Instruct-hf\"\n",
"# model = \"starcoder\"\n",
"# model = \"starcoder2-3b\"\n",
"# model = \"starcoder2-7b\"\n",
"# model = \"starcoder2-15b\"\n",
"# model = \"WizardCoder-15B-V1.0\"\n",
"# model = \"Mistral-7B-v0.1\"\n",
"# model = \"Mixtral-8x7B-v0.1\"\n",
"# llm = ChatOpenAI(openai_api_base=\"https://text-gen-api-text-gen.apps.ai.migration.redhat.com/v1\",\n",
"# streaming=True)\n",
"# temperature=0.1,\n",
"# streaming=True)\n",
"\n",
"## OpenShift AI:\n",
"## We need to make sure caikit-nlp-client is installed.\n",
"## As well as caikit_tgis_langchain.py: https://github.com/rh-aiservices-bu/llm-on-openshift/blob/main/examples/notebooks/langchain/caikit_tgis_langchain.py\n",
"## Then set the model_id and server url\n",
"## But we are having issues with responses: https://github.com/opendatahub-io/caikit-nlp-client/issues/95\n",
"# provider = \"openshift-ai\"\n",
"# model = \"codellama-7b-hf\"\n",
"# llm = caikit_tgis_langchain.CaikitLLM(inference_server_url=\"https://codellama-7b-hf-predictor-kyma-workshop.apps.ai.migration.redhat.com\",\n",
"# model_id=\"CodeLlama-7b-hf\",\n",
"# streaming=True)\n",
"\n",
"## IBM Models:\n",
"## Ensure GENAI_KEY is exported. Change the model if desired.\n",
"# provider = \"ibm\"\n",
"# model = \"ibm/granite-13b-instruct-v1\"\n",
"# model = \"ibm/granite-13b-instruct-v2\"\n",
"# model = \"ibm/granite-20b-5lang-instruct-rc\"\n",
"# model = \"ibm/granite-20b-code-instruct-v1\"\n",
"# model = \"ibm/granite-20b-code-instruct-v1-gptq\"\n",
"# model = \"ibm-mistralai/mixtral-8x7b-instruct-v01-q\"\n",
"# model = \"codellama/codellama-34b-instruct\"\n",
"# model = \"codellama/codellama-70b-instruct\"\n",
"# model = \"kaist-ai/prometheus-13b-v1\"\n",
"# model = \"mistralai/mistral-7b-instruct-v0-2\"\n",
"# model = \"thebloke/mixtral-8x7b-v0-1-gptq\"\n",
"# llm = LangChainChatInterface(\n",
"# client=Client(credentials=Credentials.from_env()),\n",
"# model_id=\"ibm/granite-20b-code-instruct-v1-gptq\",\n",
"# model_id=model,\n",
"# parameters=TextGenerationParameters(\n",
"# decoding_method=DecodingMethod.SAMPLE,\n",
"# max_new_tokens=4096,\n",
@@ -347,7 +368,35 @@
"\n",
"prompt = PromptTemplate.from_template(template_with_solved_example_and_analysis)\n",
"chain = LLMChain(llm=llm, prompt=prompt, output_parser=response_parser)\n",
"result = chain.invoke(template_args)"
"\n",
"model = model.replace(\"/\", \"_\")\n",
"output_dir = \"output/\" + provider + \"_\" + model + \"/few_shot_pydantic/\"\n",
"write_output_to_disk(\n",
" output_dir,\n",
" formatted_prompt,\n",
" src_file_contents,\n",
" src_file_from_quarkus_branch,\n",
")\n",
"\n",
"try:\n",
" result = chain.invoke(template_args)\n",
" with open(output_dir + \"result.txt\", \"a\") as f:\n",
" f.write(\"### Reasoning:\\n\")\n",
" for i in range(len(result[\"text\"].reasoning)):\n",
" f.write(result[\"text\"].reasoning[i])\n",
" f.write(\"\\n\")\n",
" for i in range(len(result[\"text\"].updated_files)):\n",
" f.write(\"### Updated file \" + str((i + 1)) + \"\\n\")\n",
" f.write(result[\"text\"].updated_files[i].file_name + \":\")\n",
" f.write(result[\"text\"].updated_files[i].file_contents)\n",
" print(\"Saved result in \" + output_dir + \"result.txt\")\n",
"except Exception as e:\n",
" with open(output_dir + \"traceback.txt\", \"a\") as f:\n",
" f.write(str(e))\n",
" f.write(traceback.format_exc())\n",
" print(\n",
" \"Something went wrong. A traceback was saved in \" + output_dir + \"traceback.txt\"\n",
" )"
]
},
{
131 changes: 131 additions & 0 deletions notebooks/pydantic/jms_to_smallrye_reactive/caikit_tgis_langchain.py
@@ -0,0 +1,131 @@
from typing import Any, Iterator, List, Mapping, Optional, Union
from warnings import warn

from caikit_nlp_client import GrpcClient, HttpClient
from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.llms.base import LLM
from langchain.schema.output import GenerationChunk


class CaikitLLM(LLM):
inference_server_url: str
model_id: str
certificate_chain: Optional[str] = None
streaming: bool
client: HttpClient

def __init__(
self,
inference_server_url: str,
model_id: str,
certificate_chain: Optional[str] = None,
streaming: bool = False,
client: Optional[GrpcClient] = HttpClient(""),
):
super().__init__(
inference_server_url=inference_server_url,
model_id=model_id,
certificate_chain=certificate_chain,
streaming=streaming,
client=client,
)

self.inference_server_url = inference_server_url
self.model_id = model_id

if certificate_chain:
with open(certificate_chain, "rb") as fh:
chain = fh.read()
else:
chain = None

if inference_server_url.startswith("http"):
client = HttpClient(inference_server_url)
else:
try:
host, port = inference_server_url.split(":")
if not all((host, port)):
raise ValueError
except ValueError:
raise ValueError(
"Invalid url provided, must be either "
'"host:port" or "http[s]://host:port/path"'
)

client = GrpcClient(host, port, ca_cert=chain)

self.client: Union[HttpClient, GrpcClient] = client

@property
def _llm_type(self) -> str:
return "caikit_tgis"

def _call(
self,
prompt: str,
preserve_input_text: bool = False,
max_new_tokens: int = 512,
min_new_tokens: int = 10,
device: str = "",
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> str:
if self.streaming:
return "".join(
self._stream(
prompt=prompt,
preserve_input_text=preserve_input_text,
max_new_tokens=max_new_tokens,
min_new_tokens=min_new_tokens,
device=device,
stop=stop,
run_manager=run_manager,
**kwargs,
)
)
if run_manager:
warn("run_manager is ignored for non-streaming use cases")

if device or stop:
raise NotImplementedError()

return self.client.generate_text(
self.model_id,
prompt,
preserve_input_text=preserve_input_text,
max_new_tokens=max_new_tokens,
min_new_tokens=min_new_tokens,
)

def _stream(
self,
prompt: str,
preserve_input_text: bool = False,
max_new_tokens: int = 512,
min_new_tokens: int = 10,
device: str = "",
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Iterator[GenerationChunk]:
if device or stop:
raise NotImplementedError

for token in self.client.generate_text_stream(
self.model_id,
prompt,
preserve_input_text=preserve_input_text,
max_new_tokens=max_new_tokens,
min_new_tokens=min_new_tokens,
):
chunk = GenerationChunk(text=token)
yield chunk.text

if run_manager:
run_manager.on_llm_new_token(chunk.text)

@property
def _identifying_params(self) -> Mapping[str, Any]:
"""Get the identifying parameters."""
return {"inference_server_url": self.inference_server_url}
@@ -0,0 +1,58 @@
/*
* JBoss, Home of Professional Open Source
* Copyright 2015, Red Hat, Inc. and/or its affiliates, and individual
* contributors by the @authors tag. See the copyright.txt in the
* distribution for a full listing of individual contributors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.jboss.as.quickstarts.mdb;

import java.util.logging.Logger;
import javax.ejb.ActivationConfigProperty;
import javax.ejb.MessageDriven;
import javax.jms.JMSException;
import javax.jms.Message;
import javax.jms.MessageListener;
import javax.jms.TextMessage;

/**
* <p>
* A simple Message Driven Bean that asynchronously receives and processes the messages that are sent to the queue.
* </p>
*
* @author Serge Pagop (spagop@redhat.com)
*/
@MessageDriven(name = "HelloWorldQueueMDB", activationConfig = {
@ActivationConfigProperty(propertyName = "destinationLookup", propertyValue = "queue/HELLOWORLDMDBQueue"),
@ActivationConfigProperty(propertyName = "destinationType", propertyValue = "javax.jms.Queue"),
@ActivationConfigProperty(propertyName = "acknowledgeMode", propertyValue = "Auto-acknowledge")})
public class HelloWorldQueueMDB implements MessageListener {

private static final Logger LOGGER = Logger.getLogger(HelloWorldQueueMDB.class.toString());

/**
* @see MessageListener#onMessage(Message)
*/
public void onMessage(Message rcvMessage) {
TextMessage msg = null;
try {
if (rcvMessage instanceof TextMessage) {
msg = (TextMessage) rcvMessage;
LOGGER.info("Received Message from queue: " + msg.getText());
} else {
LOGGER.warning("Message of wrong type: " + rcvMessage.getClass().getName());
}
} catch (JMSException e) {
throw new RuntimeException(e);
}
}
}
@@ -0,0 +1,34 @@
/*
* JBoss, Home of Professional Open Source
* Copyright 2015, Red Hat, Inc. and/or its affiliates, and individual
* contributors by the @authors tag. See the copyright.txt in the
* distribution for a full listing of individual contributors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.jboss.as.quickstarts.mdb;

import jakarta.enterprise.context.ApplicationScoped;
import io.smallrye.reactive.messaging.annotations.Merge;
import org.jboss.logging.Logger;

import org.eclipse.microprofile.reactive.messaging.Incoming;

@ApplicationScoped
public class HelloWorldQueueMDB {
private static final Logger LOGGER = Logger.getLogger(HelloWorldQueueMDB.class.toString());

@Incoming("HELLOWORLDMDBQueue")
@Merge
public void onMessage(String message) {
LOGGER.info("Received Message from queue: " + message);
}
}