Feature/to gradio #63

Open
wants to merge 14 commits into
base: main
19 changes: 19 additions & 0 deletions README.md
@@ -96,6 +96,25 @@ The system prompt is
- StableLM will refuse to participate in anything that could harm a human.
```

## Gradio Web UI
You can also run a Gradio-based web UI, found in the `gradio` folder.

![Web UI Screenshot](assets/webui-screenshot.png)

To get started, clone the repository and install the dependencies:
```
git clone https://github.com/Stability-AI/StableLM
cd StableLM
pip3 install -r gradio/requirements.txt
```
Then, run the UI:
```
python3 gradio/webui.py
```
A message will appear in the console with a link to the UI (the script serves on port 8888 by default). Click the link to open the UI in your browser.
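
The script loads the 7B tuned checkpoint by default. To try one of the other StableLM alpha checkpoints listed in `gradio/webui.py`, edit the `model_name` assignment near the top of the script before launching, for example:
```
# gradio/webui.py: pick any of the checkpoints listed in the script
model_name = "stabilityai/stablelm-tuned-alpha-3b"
```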

## Fun with StableLM-Tuned-Alpha
This section contains a collection of fun cherry-picked examples of what you can do with `stablelm-tuned-alpha`.

Binary file added assets/webui-screenshot.png
3 changes: 3 additions & 0 deletions gradio/requirements.txt
@@ -0,0 +1,3 @@
gradio==3.28.0
torch==2.0.0
transformers==4.28.1
107 changes: 107 additions & 0 deletions gradio/webui.py
Contributor

Try using the isort and black modules to reformat the Python code.

Author

fixed in 5161818

@@ -0,0 +1,107 @@
import gradio as gr
import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    StoppingCriteria,
    StoppingCriteriaList,
)


model_name = "stabilityai/stablelm-tuned-alpha-7b" # @param ["stabilityai/stablelm-tuned-alpha-7b", "stabilityai/stablelm-base-alpha-7b", "stabilityai/stablelm-tuned-alpha-3b", "stabilityai/stablelm-base-alpha-3b"]

print(f"Using `{model_name}`")

# Select "big model inference" parameters
torch_dtype = "float16" # @param ["float16", "bfloat16", "float"]
load_in_8bit = False # @param {type:"boolean"}
device_map = "auto"

print(f"Loading with: `{torch_dtype=}, {load_in_8bit=}, {device_map=}`")

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=getattr(torch, torch_dtype),
    load_in_8bit=load_in_8bit,
    device_map=device_map,
    offload_folder="./offload",
)

# @title Generate Text
# @markdown <b>Note: The model response is colored in green</b>


class StopOnTokens(StoppingCriteria):
    # Stop generation as soon as the most recently generated token is one of the stop ids below.
    def __call__(
        self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs
    ) -> bool:
        stop_ids = [50278, 50279, 50277, 1, 0]
        for stop_id in stop_ids:
            if input_ids[0][-1] == stop_id:
                return True
        return False


# Process the user prompt
user_prompt = "Can you write a song about a pirate at sea?" # @param {type:"string"}
if "tuned" in model_name:
    # Add system prompt for chat tuned models
    system_prompt = """<|SYSTEM|># StableLM Tuned (Alpha version)
- StableLM is a helpful and harmless open-source AI language model developed by StabilityAI.
- StableLM is excited to be able to help the user, but will refuse to do anything that could be considered harmful to the user.
- StableLM is more than just an information source, StableLM is also able to write poetry, short stories, and make jokes.
- StableLM will refuse to participate in anything that could harm a human.
"""
    prompt = f"{system_prompt}<|USER|>{user_prompt}<|ASSISTANT|>"
else:
    prompt = user_prompt


def complete(
    prompt, max_new_tokens=128, temperature=0.7, top_k=0, top_p=0.9, do_sample=True
):
    print(
        f"Sampling with: `{max_new_tokens=}, {temperature=}, {top_k=}, {top_p=}, {do_sample=}`"
    )

    # Create `generate` inputs
    inputs = tokenizer(prompt, return_tensors="pt")
    inputs.to(model.device)

    # Generate
    tokens = model.generate(
        **inputs,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        top_k=top_k,
        top_p=top_p,
        do_sample=do_sample,
        pad_token_id=tokenizer.eos_token_id,
        stopping_criteria=StoppingCriteriaList([StopOnTokens()]),
    )

    # Extract out only the completion tokens
    completion_tokens = tokens[0][inputs["input_ids"].size(1) :]
    completion = tokenizer.decode(completion_tokens, skip_special_tokens=True)

    return prompt + " " + completion


# Build the web UI: the components below map, in order, to `complete`'s arguments.
# Gradio 3.x sliders take (minimum, maximum, value) with `step` as a keyword.
iface = gr.Interface(
    fn=complete,
    inputs=[
        "text",
        gr.Slider(32, 3072, value=128, step=32, label="max_new_tokens"),
        gr.Slider(0.0, 1.25, value=0.7, step=0.05, label="temperature"),
        gr.Slider(0.0, 1.0, value=0.0, step=0.05, label="top_k"),
        gr.Slider(0.0, 1.0, value=0.9, step=0.05, label="top_p"),
        "checkbox",
    ],
    outputs="text",
    title="StableLM",
    description="StableLM web ui",
)

# `share` is an argument of `launch()`, not of the `Interface` constructor.
iface.launch(share=True, server_port=8888)
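
As a quick sanity check without the browser, the same `complete` function can be called directly from Python (a minimal sketch; place it before the `iface.launch(...)` call, which blocks until the server is stopped):
```
# Smoke test for the generation pipeline; assumes the model above has already loaded.
text = complete(
    "What is the capital of France?",
    max_new_tokens=64,
    temperature=0.7,
    top_k=0,
    top_p=0.9,
    do_sample=True,
)
print(text)
```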