Commit
Update pydantic notebook to output results/traceback to file
Signed-off-by: Jason Montleon <jmontleo@redhat.com>
jmontleon committed Feb 28, 2024
1 parent 7fc02b0 commit 323de79
Showing 58 changed files with 7,149 additions and 21 deletions.
@@ -59,7 +59,6 @@
" formatted_template,\n",
" example_javaee_file_contents,\n",
" example_quarkus_file_contents,\n",
" result,\n",
"):\n",
" try:\n",
" os.makedirs(out_dir, exist_ok=True)\n",
@@ -74,21 +73,6 @@
" f.write(formatted_template)\n",
" print(f\"Saved template to {template_path}\")\n",
"\n",
" # Save the result\n",
" result_path = os.path.join(out_dir, \"result.txt\")\n",
" with open(result_path, \"w\") as f:\n",
" f.truncate()\n",
" f.write(result)\n",
" print(f\"Saved result to {result_path}\")\n",
"\n",
" # Extract the updated java code and save it\n",
" updated_file_contents = get_java_in_result(result)\n",
" updated_file_path = os.path.join(out_dir, \"updated_file.java\")\n",
" with open(updated_file_path, \"w\") as f:\n",
" f.truncate()\n",
" f.write(updated_file_contents)\n",
" print(f\"Saved updated java file to {updated_file_path}\")\n",
"\n",
" # Save the original source code\n",
" original_file_path = os.path.join(out_dir, \"original_file.java\")\n",
" with open(original_file_path, \"w\") as f:\n",
@@ -255,6 +239,9 @@
"metadata": {},
"outputs": [],
"source": [
"import caikit_tgis_langchain\n",
"import traceback\n",
"\n",
"from langchain_openai import ChatOpenAI\n",
"from langchain.output_parsers import PydanticOutputParser\n",
"from langchain_core.pydantic_v1 import BaseModel, Field, validator\n",
@@ -299,19 +286,35 @@
"\n",
"formatted_prompt = template_with_solved_example_and_analysis.format(**template_args)\n",
"\n",
"# For comparison, we will look up what the source file looks like from Quarkus branch\n",
"# this serves as a way of comparing to what the 'answer' is that was done manually by a knowledgeable developer\n",
"src_file_from_quarkus_branch = helloworld_diff.get_file_contents(\n",
" src_file_name, \"quarkus\"\n",
")\n",
"\n",
"### Choose one of the options below to run against\n",
"\n",
"## OpenAI: Specify a model and make sure OPENAI_API_KEY is exported\n",
"# model_name = \"gpt-3.5-turbo\"\n",
"# model_name = \"gpt-4-1106-preview\"\n",
"# provider = \"openai\"\n",
"# model = \"gpt-3.5-turbo\"\n",
"# model = \"gpt-4-1106-preview\"\n",
"# llm = ChatOpenAI(temperature=0.1,\n",
"# model_name=model_name,\n",
"# model_name=model,\n",
"# streaming=True)\n",
"\n",
"## Oobabooga: text-gen uses the OpenAI API so it works similarly.\n",
"## We load a model and specify temperature on the server so we don't need to specify them.\n",
"## We've had decent results with Mistral-7b and Codellama-34b-Instruct\n",
"## Just the URL for the server and make sure OPENAI_API_KEY is exported\n",
"# provider = \"text-gen\"\n",
"# model = \"CodeLlama-7b-Instruct-hf\"\n",
"# model = \"CodeLlama-13b-Instruct-hf\"\n",
"# model = \"CodeLlama-34b-Instruct-hf\"\n",
"# model = \"CodeLlama-70b-Instruct-hf\"\n",
"# model = \"starcoder\"\n",
"# model = \"WizardCoder-15B-V1.0\"\n",
"# model = \"Mistral-7B-v0.1\"\n",
"# model = \"Mixtral-8x7B-v0.1\"\n",
"# llm = ChatOpenAI(openai_api_base=\"https://text-gen-api-text-gen.apps.ai.migration.redhat.com/v1\",\n",
"# streaming=True)\n",
"\n",
@@ -320,15 +323,29 @@
"## As well as caikit_tgis_langchain.py: https://github.com/rh-aiservices-bu/llm-on-openshift/blob/main/examples/notebooks/langchain/caikit_tgis_langchain.py\n",
"## Then set the model_id and server url\n",
"## But we are having issues with responses: https://github.com/opendatahub-io/caikit-nlp-client/issues/95\n",
"# provider = \"openshift-ai\"\n",
"# model = \"codellama-7b-hf\"\n",
"# llm = caikit_tgis_langchain.CaikitLLM(inference_server_url=\"https://codellama-7b-hf-predictor-kyma-workshop.apps.ai.migration.redhat.com\",\n",
"# model_id=\"CodeLlama-7b-hf\",\n",
"# streaming=True)\n",
"\n",
"## IBM Models:\n",
"## Ensure GENAI_KEY is exported. Change the model if desired.\n",
"# provider = \"ibm\"\n",
"# model = \"ibm/granite-13b-instruct-v1\"\n",
"# model = \"ibm/granite-13b-instruct-v2\"\n",
"# model = \"ibm/granite-20b-5lang-instruct-rc\"\n",
"# model = \"ibm/granite-20b-code-instruct-v1\"\n",
"# model = \"ibm/granite-20b-code-instruct-v1-gptq\"\n",
"# model = \"ibm-mistralai/mixtral-8x7b-instruct-v01-q\"\n",
"# model = \"codellama/codellama-34b-instruct\"\n",
"# model = \"codellama/codellama-70b-instruct\"\n",
"# model = \"kaist-ai/prometheus-13b-v1\"\n",
"# model = \"mistralai/mistral-7b-instruct-v0-2\"\n",
"# model = \"thebloke/mixtral-8x7b-v0-1-gptq\"\n",
"# llm = LangChainChatInterface(\n",
"# client=Client(credentials=Credentials.from_env()),\n",
"# model_id=\"ibm/granite-20b-code-instruct-v1-gptq\",\n",
"# model_id=model,\n",
"# parameters=TextGenerationParameters(\n",
"# decoding_method=DecodingMethod.SAMPLE,\n",
"# max_new_tokens=4096,\n",
@@ -347,7 +364,35 @@
"\n",
"prompt = PromptTemplate.from_template(template_with_solved_example_and_analysis)\n",
"chain = LLMChain(llm=llm, prompt=prompt, output_parser=response_parser)\n",
"result = chain.invoke(template_args)"
"\n",
"model = model.replace(\"/\", \"_\")\n",
"output_dir = \"output/\" + provider + \"_\" + model + \"/few_shot_pydantic/\"\n",
"write_output_to_disk(\n",
" output_dir,\n",
" formatted_prompt,\n",
" src_file_contents,\n",
" src_file_from_quarkus_branch,\n",
")\n",
"\n",
"try:\n",
" result = chain.invoke(template_args)\n",
" with open(output_dir + \"result.txt\", \"a\") as f:\n",
" f.write(\"### Reasoning:\\n\")\n",
" for i in range(len(result[\"text\"].reasoning)):\n",
" f.write(result[\"text\"].reasoning[i])\n",
" f.write(\"\\n\")\n",
" for i in range(len(result[\"text\"].updated_files)):\n",
" f.write(\"### Updated file \" + str((i + 1)) + \"\\n\")\n",
" f.write(result[\"text\"].updated_files[i].file_name + \":\")\n",
" f.write(result[\"text\"].updated_files[i].file_contents)\n",
" print(\"Saved result in \" + output_dir + \"result.txt\")\n",
"except Exception as e:\n",
" with open(output_dir + \"traceback.txt\", \"a\") as f:\n",
" f.write(str(e))\n",
" f.write(traceback.format_exc())\n",
" print(\n",
" \"Something went wrong. A traceback was saved in \" + output_dir + \"traceback.txt\"\n",
" )"
]
},
{
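The new cell above reads result["text"].reasoning and result["text"].updated_files[i].file_name / file_contents, so the PydanticOutputParser is presumably built around a response model with those fields. A minimal sketch of such a model, assuming the langchain pydantic v1 shim imported earlier in the notebook; the class names and Field descriptions here are illustrative, not taken from the commit:

from typing import List

from langchain.output_parsers import PydanticOutputParser
from langchain_core.pydantic_v1 import BaseModel, Field


class UpdatedFile(BaseModel):
    # Field names mirror what the notebook cell accesses on each updated file.
    file_name: str = Field(description="Name of the updated file")
    file_contents: str = Field(description="Full contents of the updated file")


class Response(BaseModel):
    # The cell iterates over reasoning as a list of strings and over updated_files.
    reasoning: List[str] = Field(description="Step-by-step reasoning behind the changes")
    updated_files: List[UpdatedFile] = Field(description="Files updated during the migration")


response_parser = PydanticOutputParser(pydantic_object=Response)

With a model like this, chain.invoke returns a dict whose "text" entry is the parsed Response instance, which is presumably why the file-writing loop above indexes result["text"].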
131 changes: 131 additions & 0 deletions notebooks/jms_to_smallrye_reactive/caikit_tgis_langchain.py
@@ -0,0 +1,131 @@
from typing import Any, Iterator, List, Mapping, Optional, Union
from warnings import warn

from caikit_nlp_client import GrpcClient, HttpClient
from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.llms.base import LLM
from langchain.schema.output import GenerationChunk


class CaikitLLM(LLM):
inference_server_url: str
model_id: str
certificate_chain: Optional[str] = None
streaming: bool
client: HttpClient

def __init__(
self,
inference_server_url: str,
model_id: str,
certificate_chain: Optional[str] = None,
streaming: bool = False,
client: Optional[GrpcClient] = HttpClient(""),

Check failure on line 23 in notebooks/jms_to_smallrye_reactive/caikit_tgis_langchain.py (GitHub Actions / Trunk Check): ruff(B008) [new] Do not perform function call `HttpClient` in argument defaults; instead, perform the call within the function, or read the default from a module-level singleton variable
):
super().__init__(
inference_server_url=inference_server_url,
model_id=model_id,
certificate_chain=certificate_chain,
streaming=streaming,
client=client,
)

self.inference_server_url = inference_server_url
self.model_id = model_id

if certificate_chain:
with open(certificate_chain, "rb") as fh:
chain = fh.read()
else:
chain = None

if inference_server_url.startswith("http"):
client = HttpClient(inference_server_url)
else:
try:
host, port = inference_server_url.split(":")
if not all((host, port)):
raise ValueError
except ValueError:
raise ValueError(

Check failure on line 50 in notebooks/jms_to_smallrye_reactive/caikit_tgis_langchain.py (GitHub Actions / Trunk Check): ruff(B904) [new] Within an `except` clause, raise exceptions with `raise ... from err` or `raise ... from None` to distinguish them from errors in exception handling
"Invalid url provided, must be either "
'"host:port" or "http[s]://host:port/path"'
)

client = GrpcClient(host, port, ca_cert=chain)

self.client: Union[HttpClient, GrpcClient] = client

@property
def _llm_type(self) -> str:
return "caikit_tgis"

def _call(
self,
prompt: str,
preserve_input_text: bool = False,
max_new_tokens: int = 512,
min_new_tokens: int = 10,
device: str = "",
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> str:
if self.streaming:
return "".join(
self._stream(
prompt=prompt,
preserve_input_text=preserve_input_text,
max_new_tokens=max_new_tokens,
min_new_tokens=min_new_tokens,
device=device,
stop=stop,
run_manager=run_manager,
**kwargs,
)
)
if run_manager:
warn("run_manager is ignored for non-streaming use cases")

Check failure on line 88 in notebooks/jms_to_smallrye_reactive/caikit_tgis_langchain.py (GitHub Actions / Trunk Check): ruff(B028) [new] No explicit `stacklevel` keyword argument found

if device or stop:
raise NotImplementedError()

return self.client.generate_text(
self.model_id,
prompt,
preserve_input_text=preserve_input_text,
max_new_tokens=max_new_tokens,
min_new_tokens=min_new_tokens,
)

def _stream(
self,
prompt: str,
preserve_input_text: bool = False,
max_new_tokens: int = 512,
min_new_tokens: int = 10,
device: str = "",
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Iterator[GenerationChunk]:
if device or stop:
raise NotImplementedError

for token in self.client.generate_text_stream(
self.model_id,
prompt,
preserve_input_text=preserve_input_text,
max_new_tokens=max_new_tokens,
min_new_tokens=min_new_tokens,
):
chunk = GenerationChunk(text=token)
yield chunk.text

if run_manager:
run_manager.on_llm_new_token(chunk.text)

@property
def _identifying_params(self) -> Mapping[str, Any]:
"""Get the identifying parameters."""
return {"inference_server_url": self.inference_server_url}
@@ -0,0 +1,58 @@
/*
* JBoss, Home of Professional Open Source
* Copyright 2015, Red Hat, Inc. and/or its affiliates, and individual
* contributors by the @authors tag. See the copyright.txt in the
* distribution for a full listing of individual contributors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.jboss.as.quickstarts.mdb;

import java.util.logging.Logger;
import javax.ejb.ActivationConfigProperty;
import javax.ejb.MessageDriven;
import javax.jms.JMSException;
import javax.jms.Message;
import javax.jms.MessageListener;
import javax.jms.TextMessage;

/**
* <p>
* A simple Message Driven Bean that asynchronously receives and processes the messages that are sent to the queue.
* </p>
*
* @author Serge Pagop (spagop@redhat.com)
*/
@MessageDriven(name = "HelloWorldQueueMDB", activationConfig = {
@ActivationConfigProperty(propertyName = "destinationLookup", propertyValue = "queue/HELLOWORLDMDBQueue"),
@ActivationConfigProperty(propertyName = "destinationType", propertyValue = "javax.jms.Queue"),
@ActivationConfigProperty(propertyName = "acknowledgeMode", propertyValue = "Auto-acknowledge")})
public class HelloWorldQueueMDB implements MessageListener {

private static final Logger LOGGER = Logger.getLogger(HelloWorldQueueMDB.class.toString());

/**
* @see MessageListener#onMessage(Message)
*/
public void onMessage(Message rcvMessage) {
TextMessage msg = null;
try {
if (rcvMessage instanceof TextMessage) {
msg = (TextMessage) rcvMessage;
LOGGER.info("Received Message from queue: " + msg.getText());
} else {
LOGGER.warning("Message of wrong type: " + rcvMessage.getClass().getName());
}
} catch (JMSException e) {
throw new RuntimeException(e);
}
}
}
@@ -0,0 +1,34 @@
/*
* JBoss, Home of Professional Open Source
* Copyright 2015, Red Hat, Inc. and/or its affiliates, and individual
* contributors by the @authors tag. See the copyright.txt in the
* distribution for a full listing of individual contributors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.jboss.as.quickstarts.mdb;

import jakarta.enterprise.context.ApplicationScoped;
import io.smallrye.reactive.messaging.annotations.Merge;
import org.jboss.logging.Logger;

import org.eclipse.microprofile.reactive.messaging.Incoming;

@ApplicationScoped
public class HelloWorldQueueMDB {
private static final Logger LOGGER = Logger.getLogger(HelloWorldQueueMDB.class.toString());

@Incoming("HELLOWORLDMDBQueue")
@Merge
public void onMessage(String message) {
LOGGER.info("Received Message from queue: " + message);
}
}