diff --git a/README.md b/README.md
index 5bfe5c8..2193050 100644
--- a/README.md
+++ b/README.md
@@ -96,6 +96,22 @@ The system prompt is
 - StableLM will refuse to participate in anything that could harm a human.
 ```
+## Gradio Web UI
+You can also run a Gradio-based web UI, found in the `gradio` folder.
+
+![Web UI Screenshot](assets/webui-screenshot.png)
+
+To get started, clone the project and install the dependencies:
+```
+git clone https://github.com/Stability-AI/StableLM
+cd StableLM
+pip3 install -r gradio/requirements.txt
+```
+Then run the UI:
+```
+python3 gradio/webui.py
+```
+A message with a link to the UI will appear in the console. Open the link in your browser to use the UI.
 
 ## Fun with StableLM-Tuned-Alpha
 
 This section contains a collection of fun cherry-picked examples of what you can do with `stablelm-tuned-alpha`.
diff --git a/assets/webui-screenshot.png b/assets/webui-screenshot.png
new file mode 100644
index 0000000..24da903
Binary files /dev/null and b/assets/webui-screenshot.png differ
diff --git a/gradio/requirements.txt b/gradio/requirements.txt
new file mode 100644
index 0000000..b4e4943
--- /dev/null
+++ b/gradio/requirements.txt
@@ -0,0 +1,3 @@
+gradio==3.28.0
+torch==2.0.0
+transformers==4.28.1
diff --git a/gradio/webui.py b/gradio/webui.py
new file mode 100644
index 0000000..da9c6a2
--- /dev/null
+++ b/gradio/webui.py
@@ -0,0 +1,104 @@
+import gradio as gr
+import torch
+from transformers import (
+    AutoModelForCausalLM,
+    AutoTokenizer,
+    StoppingCriteria,
+    StoppingCriteriaList,
+)
+
+
+model_name = "stabilityai/stablelm-tuned-alpha-7b"  # @param ["stabilityai/stablelm-tuned-alpha-7b", "stabilityai/stablelm-base-alpha-7b", "stabilityai/stablelm-tuned-alpha-3b", "stabilityai/stablelm-base-alpha-3b"]
+
+print(f"Using `{model_name}`")
+
+# Select "big model inference" parameters
+torch_dtype = "float16"  # @param ["float16", "bfloat16", "float"]
+load_in_8bit = False  # @param {type:"boolean"}
+device_map = "auto"
+
+print(f"Loading with: `{torch_dtype=}, {load_in_8bit=}, {device_map=}`")
+
+tokenizer = AutoTokenizer.from_pretrained(model_name)
+model = AutoModelForCausalLM.from_pretrained(
+    model_name,
+    torch_dtype=getattr(torch, torch_dtype),
+    load_in_8bit=load_in_8bit,
+    device_map=device_map,
+    offload_folder="./offload",
+)
+
+
+class StopOnTokens(StoppingCriteria):
+    # Stop generation as soon as a chat special token or end-of-text token is produced
+    def __call__(
+        self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs
+    ) -> bool:
+        stop_ids = [50278, 50279, 50277, 1, 0]
+        for stop_id in stop_ids:
+            if input_ids[0][-1] == stop_id:
+                return True
+        return False
+
+
+# System prompt prepended to requests sent to the chat-tuned models
+system_prompt = """<|SYSTEM|># StableLM Tuned (Alpha version)
+- StableLM is a helpful and harmless open-source AI language model developed by StabilityAI.
+- StableLM is excited to be able to help the user, but will refuse to do anything that could be considered harmful to the user.
+- StableLM is more than just an information source, StableLM is also able to write poetry, short stories, and make jokes.
+- StableLM will refuse to participate in anything that could harm a human.
+"""
+
+
+def complete(
+    user_prompt, max_new_tokens=128, temperature=0.7, top_k=0, top_p=0.9, do_sample=True
+):
+    print(
+        f"Sampling with: `{max_new_tokens=}, {temperature=}, {top_k=}, {top_p=}, {do_sample=}`"
+    )
+
+    # Wrap the user prompt in the chat format expected by the tuned models
+    if "tuned" in model_name:
+        prompt = f"{system_prompt}<|USER|>{user_prompt}<|ASSISTANT|>"
+    else:
+        prompt = user_prompt
+
+    # Create `generate` inputs
+    inputs = tokenizer(prompt, return_tensors="pt")
+    inputs = inputs.to(model.device)
+
+    # Generate
+    tokens = model.generate(
+        **inputs,
+        max_new_tokens=int(max_new_tokens),
+        temperature=temperature,
+        top_k=int(top_k),
+        top_p=top_p,
+        do_sample=do_sample,
+        pad_token_id=tokenizer.eos_token_id,
+        stopping_criteria=StoppingCriteriaList([StopOnTokens()]),
+    )
+
+    # Keep only the completion tokens (everything after the prompt)
+    completion_tokens = tokens[0][inputs["input_ids"].size(1) :]
+    completion = tokenizer.decode(completion_tokens, skip_special_tokens=True)
+
+    return user_prompt + " " + completion
+
+
+iface = gr.Interface(
+    fn=complete,
+    inputs=[
+        "text",
+        gr.Slider(32, 3072, value=128, step=32, label="max_new_tokens"),
+        gr.Slider(0.0, 1.25, value=0.7, step=0.05, label="temperature"),
+        gr.Slider(0, 100, value=0, step=1, label="top_k"),
+        gr.Slider(0.0, 1.0, value=0.9, step=0.05, label="top_p"),
+        gr.Checkbox(value=True, label="do_sample"),
+    ],
+    outputs="text",
+    title="StableLM",
+    description="StableLM web UI",
+)
+
+iface.launch(server_port=8888, share=True)
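
Once `gradio/webui.py` is running, the interface can also be scripted against instead of used through the browser. The snippet below is a minimal sketch, not part of the diff: it assumes the server is running locally on the script's default port 8888 and uses Gradio 3.x's `/run/predict` HTTP endpoint; the six entries in `data` mirror the `Interface` inputs (prompt, max_new_tokens, temperature, top_k, top_p, do_sample), and the prompt text is only an example.

```python
# Minimal sketch (not part of the diff): call the running web UI over Gradio 3.x's HTTP API.
# Assumes gradio/webui.py is running locally on its default port 8888.
import requests

payload = {
    # One value per Interface input, in order:
    # prompt, max_new_tokens, temperature, top_k, top_p, do_sample
    "data": ["Write a short poem about the sea.", 128, 0.7, 0, 0.9, True],
}

resp = requests.post("http://127.0.0.1:8888/run/predict", json=payload, timeout=600)
resp.raise_for_status()

# Gradio returns the Interface outputs under "data"; here that is the generated text.
print(resp.json()["data"][0])
```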