diff --git a/candle-examples/examples/stable-lm/README.md b/candle-examples/examples/stable-lm/README.md
index ad3e4a5b60..485812d3e5 100644
--- a/candle-examples/examples/stable-lm/README.md
+++ b/candle-examples/examples/stable-lm/README.md
@@ -8,6 +8,11 @@ Card](https://huggingface.co/stabilityai/stablelm-3b-4e1t).
 Note that this model is gated so you will have to request access on the Hub in
 order to be able to use it.
 
+Other available models are Stable-Code-3B, StableLM-2 and the Zephyr variants.
+
+StableLM-2 uses a Tiktoken-based GPT-3.5/GPT-4 tokenizer not supported by Candle, so to run it you can download a somewhat compatible [tokenizer.json](https://huggingface.co/Xenova/gpt-4/resolve/main/tokenizer.json?download=true)
+and pass it via the --tokenizer-file argument.
+
 ## Running some example
 
 ```bash
diff --git a/candle-examples/examples/stable-lm/main.rs b/candle-examples/examples/stable-lm/main.rs
index ccd924a40f..415c6e7e7c 100644
--- a/candle-examples/examples/stable-lm/main.rs
+++ b/candle-examples/examples/stable-lm/main.rs
@@ -5,7 +5,7 @@ extern crate intel_mkl_src;
 extern crate accelerate_src;
 
 use anyhow::{Error as E, Result};
-use clap::Parser;
+use clap::{Parser, ValueEnum};
 
 use candle_transformers::models::quantized_stable_lm::Model as QStableLM;
 use candle_transformers::models::stable_lm::{Config, Model as StableLM};
@@ -122,6 +122,16 @@ impl TextGeneration {
     }
 }
 
+#[derive(Clone, Copy, Debug, ValueEnum, PartialEq, Eq)]
+enum Which {
+    V1Orig,
+    V1,
+    V1Zephyr,
+    V2,
+    V2Zephyr,
+    Code,
+}
+
 #[derive(Parser, Debug)]
 #[command(author, version, about, long_about = None)]
 struct Args {
@@ -155,12 +165,15 @@ struct Args {
     #[arg(long, short = 'n', default_value_t = 100)]
     sample_len: usize,
 
-    #[arg(long, default_value = "lmz/candle-stablelm-3b-4e1t")]
-    model_id: String,
+    #[arg(long)]
+    model_id: Option<String>,
 
     #[arg(long, default_value = "main")]
     revision: String,
 
+    #[arg(long, default_value = "v1-orig")]
+    which: Which,
+
     #[arg(long)]
     tokenizer_file: Option<String>,
 
@@ -207,8 +220,20 @@ fn main() -> Result<()> {
 
     let start = std::time::Instant::now();
     let api = Api::new()?;
+    let model_id = match args.model_id {
+        Some(model_id) => model_id,
+        None => match args.which {
+            Which::V1Orig => "lmz/candle-stablelm-3b-4e1t".to_string(),
+            Which::V1 => "stabilityai/stablelm-3b-4e1t".to_string(),
+            Which::V1Zephyr => "stabilityai/stablelm-zephyr-3b".to_string(),
+            Which::Code => "stabilityai/stable-code-3b".to_string(),
+            Which::V2 => "stabilityai/stablelm-2-1_6b".to_string(),
+            Which::V2Zephyr => "stabilityai/stablelm-2-zephyr-1_6b".to_string(),
+        },
+    };
+
     let repo = api.repo(Repo::with_revision(
-        args.model_id,
+        model_id,
         RepoType::Model,
         args.revision,
     ));
@@ -221,19 +246,35 @@ fn main() -> Result<()> {
             .split(',')
             .map(std::path::PathBuf::from)
             .collect::<Vec<_>>(),
-        None => {
-            if args.quantized {
-                vec![repo.get("model-q4k.gguf")?]
-            } else {
+        None => match (args.which, args.quantized) {
+            (Which::V1Orig, true) => vec![repo.get("model-q4k.gguf")?],
+            (Which::V1 | Which::V1Zephyr | Which::V2 | Which::V2Zephyr | Which::Code, true) => {
+                anyhow::bail!("Quantized {:?} variant not supported.", args.which)
+            }
+            (Which::V1Orig | Which::V1 | Which::V1Zephyr | Which::V2 | Which::V2Zephyr, false) => {
                 vec![repo.get("model.safetensors")?]
             }
-        }
+            (Which::Code, false) => {
+                candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")?
+            }
+        },
     };
+
     println!("retrieved the files in {:?}", start.elapsed());
     let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
 
     let start = std::time::Instant::now();
-    let config = Config::stablelm_3b_4e1t(args.use_flash_attn);
+    let config = match args.which {
+        Which::V1Orig => Config::stablelm_3b_4e1t(args.use_flash_attn),
+        Which::V1 | Which::V1Zephyr | Which::V2 | Which::V2Zephyr | Which::Code => {
+            let config_filename = repo.get("config.json")?;
+            let config = std::fs::read_to_string(config_filename)?;
+            let mut config: Config = serde_json::from_str(&config)?;
+            config.set_use_flash_attn(args.use_flash_attn);
+            config
+        }
+    };
+
     let device = candle_examples::device(args.cpu)?;
     let (model, device) = if args.quantized {
         let filename = &filenames[0];
diff --git a/candle-transformers/src/models/stable_lm.rs b/candle-transformers/src/models/stable_lm.rs
index ef06ea99c0..a49b82820a 100644
--- a/candle-transformers/src/models/stable_lm.rs
+++ b/candle-transformers/src/models/stable_lm.rs
@@ -1,10 +1,11 @@
-use crate::models::with_tracing::{linear_no_bias, Linear};
+use crate::models::with_tracing::{linear, linear_no_bias, Linear};
 use candle::{DType, Device, Module, Result, Tensor, D};
 use candle_nn::{Activation, LayerNorm, VarBuilder};
+use serde::Deserialize;
 use std::sync::Arc;
 
 // https://huggingface.co/stabilityai/stablelm-3b-4e1t/blob/main/configuration_stablelm_epoch.py
-#[derive(Debug, Clone, PartialEq)]
+#[derive(Debug, Clone, PartialEq, Deserialize)]
 pub struct Config {
     pub(crate) vocab_size: usize,
     pub(crate) intermediate_size: usize,
@@ -18,7 +19,10 @@ pub struct Config {
     pub(crate) max_position_embeddings: usize,
     pub(crate) norm_eps: f64,
     pub(crate) use_cache: bool,
-    pub(crate) use_flash_attn: bool,
+    #[serde(default)]
+    pub(crate) use_qkv_bias: bool, // Used in StableLM-2
+    #[serde(default)]
+    pub(crate) use_flash_attn: bool, // Not in config.json
 }
 
 impl Config {
@@ -35,6 +39,7 @@ impl Config {
             rope_theta: 10_000.,
             max_position_embeddings: 4096,
             norm_eps: 1e-5,
+            use_qkv_bias: false,
             use_cache: true,
             use_flash_attn,
         }
@@ -51,6 +56,10 @@ impl Config {
     pub fn num_kv_groups(&self) -> usize {
         self.num_attention_heads / self.num_key_value_heads
     }
+
+    pub fn set_use_flash_attn(&mut self, use_flash_attn: bool) {
+        self.use_flash_attn = use_flash_attn
+    }
 }
 
 #[derive(Debug)]
@@ -179,9 +188,15 @@ impl Attention {
         let head_dim = cfg.head_dim();
         let num_heads = cfg.num_attention_heads;
         let num_kv_heads = cfg.num_key_value_heads;
-        let q_proj = linear_no_bias(hidden_sz, num_heads * head_dim, vb.pp("q_proj"))?;
-        let k_proj = linear_no_bias(hidden_sz, num_kv_heads * head_dim, vb.pp("k_proj"))?;
-        let v_proj = linear_no_bias(hidden_sz, num_kv_heads * head_dim, vb.pp("v_proj"))?;
+        let linear_layer = if cfg.use_qkv_bias {
+            linear
+        } else {
+            linear_no_bias
+        };
+
+        let q_proj = linear_layer(hidden_sz, num_heads * head_dim, vb.pp("q_proj"))?;
+        let k_proj = linear_layer(hidden_sz, num_kv_heads * head_dim, vb.pp("k_proj"))?;
+        let v_proj = linear_layer(hidden_sz, num_kv_heads * head_dim, vb.pp("v_proj"))?;
         let o_proj = linear_no_bias(num_heads * head_dim, hidden_sz, vb.pp("o_proj"))?;
         Ok(Self {
             q_proj,
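
Usage note (not part of the patch): clap's `ValueEnum` derive renders the `Which` variants in kebab-case, so the new flag accepts `v1-orig` (the default), `v1`, `v1-zephyr`, `v2`, `v2-zephyr` and `code`. A sketch of how a StableLM-2 run might look, assuming the tokenizer.json from the README has been downloaded into the working directory (the file path is an assumption, adjust to taste):

```bash
# Grab the somewhat-compatible GPT-4 tokenizer mentioned in the README.
wget -O tokenizer.json "https://huggingface.co/Xenova/gpt-4/resolve/main/tokenizer.json?download=true"

# Select the StableLM-2 1.6B weights via the new --which flag and pass the
# downloaded tokenizer explicitly with --tokenizer-file.
cargo run --example stable-lm --release -- \
  --which v2 \
  --tokenizer-file tokenizer.json \
  --prompt "What is the most efficient programming language in use?" \
  --sample-len 100
```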
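On the `#[serde(default)]` attributes in `Config`: Hub `config.json` files never carry `use_flash_attn`, and pre-StableLM-2 configs omit `use_qkv_bias`, so both fields must fall back to `bool::default()` (`false`) rather than fail deserialization. A minimal, self-contained sketch of that behavior using a stand-in struct (not the real `Config`):

```rust
use serde::Deserialize;

// Stand-in for the real Config; only illustrates the defaulting behavior.
#[derive(Debug, Deserialize)]
struct MiniConfig {
    vocab_size: usize,
    #[serde(default)]
    use_qkv_bias: bool, // absent in pre-StableLM-2 configs -> false
    #[serde(default)]
    use_flash_attn: bool, // never present in config.json -> false
}

fn main() -> serde_json::Result<()> {
    // A config.json without either optional field still deserializes cleanly.
    let cfg: MiniConfig = serde_json::from_str(r#"{"vocab_size": 50304}"#)?;
    assert!(!cfg.use_qkv_bias && !cfg.use_flash_attn);
    println!("{cfg:?}");
    Ok(())
}
```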
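The `linear_layer` selection in `Attention::new` works because `linear` and `linear_no_bias` share a signature, so both `if` arms coerce to the same function-pointer type. Reduced to a standalone sketch (the names below are illustrative, not candle APIs):

```rust
// Two constructors with the same signature; in an if/else, both fn items
// coerce to the common pointer type fn(usize, usize) -> String.
fn with_bias(in_dim: usize, out_dim: usize) -> String {
    format!("Linear {in_dim}x{out_dim} + bias")
}

fn without_bias(in_dim: usize, out_dim: usize) -> String {
    format!("Linear {in_dim}x{out_dim}")
}

fn main() {
    let use_qkv_bias = true; // mirrors cfg.use_qkv_bias
    let linear_layer = if use_qkv_bias { with_bias } else { without_bias };
    // One call site builds either flavor, as the q/k/v projections do.
    println!("{}", linear_layer(2560, 2560));
}
```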