Add StableLM-2, StableLM Code and Zephyr variants #1650

Merged (3 commits) on Feb 3, 2024
5 changes: 5 additions & 0 deletions candle-examples/examples/stable-lm/README.md
@@ -8,6 +8,11 @@ Card](https://huggingface.co/stabilityai/stablelm-3b-4e1t).
Note that this model is gated, so you will have to request access on the Hub in
order to use it.

Other available models are the Stable-Code-3B, StableLM-2, and Zephyr variants.

StableLM-2 uses a Tiktoken-based GPT-3.5/GPT-4 tokenizer that is not supported by Candle. To run it, download a somewhat compatible [tokenizer.json](https://huggingface.co/Xenova/gpt-4/resolve/main/tokenizer.json?download=true)
and pass it via the `--tokenizer-file` argument.
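
For example, a StableLM-2 invocation might look like this (a sketch: the `--prompt` flag is assumed from the example's existing CLI, and the prompt text and local path are just illustrative):

```bash
wget -O tokenizer.json 'https://huggingface.co/Xenova/gpt-4/resolve/main/tokenizer.json?download=true'
cargo run --example stable-lm --release -- \
  --which v2 --tokenizer-file tokenizer.json \
  --prompt 'What is the capital of France?' --sample-len 50
```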

## Running an example

```bash
# (example invocation collapsed in the diff view)
```
61 changes: 51 additions & 10 deletions candle-examples/examples/stable-lm/main.rs
@@ -5,7 +5,7 @@ extern crate intel_mkl_src;
extern crate accelerate_src;

use anyhow::{Error as E, Result};
-use clap::Parser;
+use clap::{Parser, ValueEnum};

use candle_transformers::models::quantized_stable_lm::Model as QStableLM;
use candle_transformers::models::stable_lm::{Config, Model as StableLM};
@@ -122,6 +122,16 @@ impl TextGeneration {
}
}

#[derive(Clone, Copy, Debug, ValueEnum, PartialEq, Eq)]
enum Which {
V1Orig,
V1,
V1Zephyr,
V2,
V2Zephyr,
Code,
}
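
// clap's `ValueEnum` derive maps the variant names above to kebab-case CLI
// values, so `--which v1-orig` selects `Which::V1Orig` and `--which
// v2-zephyr` selects `Which::V2Zephyr`.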

#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
@@ -155,12 +165,15 @@ struct Args {
#[arg(long, short = 'n', default_value_t = 100)]
sample_len: usize,

-    #[arg(long, default_value = "lmz/candle-stablelm-3b-4e1t")]
-    model_id: String,
+    #[arg(long)]
+    model_id: Option<String>,

#[arg(long, default_value = "main")]
revision: String,

#[arg(long, default_value = "v1-orig")]
which: Which,

#[arg(long)]
tokenizer_file: Option<String>,

@@ -207,8 +220,20 @@ fn main() -> Result<()> {

let start = std::time::Instant::now();
let api = Api::new()?;
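    // The Hub repo is derived from `--which` below unless `--model-id`
    // overrides it explicitly.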
let model_id = match args.model_id {
Some(model_id) => model_id,
None => match args.which {
Which::V1Orig => "lmz/candle-stablelm-3b-4e1t".to_string(),
Which::V1 => "stabilityai/stablelm-3b-4e1t".to_string(),
Which::V1Zephyr => "stabilityai/stablelm-zephyr-3b".to_string(),
Which::Code => "stabilityai/stable-code-3b".to_string(),
Which::V2 => "stabilityai/stablelm-2-1_6b".to_string(),
Which::V2Zephyr => "stabilityai/stablelm-2-zephyr-1_6b".to_string(),
},
};

let repo = api.repo(Repo::with_revision(
-        args.model_id,
+        model_id,
RepoType::Model,
args.revision,
));
@@ -221,19 +246,35 @@
.split(',')
.map(std::path::PathBuf::from)
.collect::<Vec<_>>(),
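        // Default weight files per variant: V1Orig ships a prequantized
        // q4k gguf, Code shards its weights behind
        // model.safetensors.index.json, and the rest use a single
        // model.safetensors.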
-        None => {
-            if args.quantized {
-                vec![repo.get("model-q4k.gguf")?]
-            } else {
-                vec![repo.get("model.safetensors")?]
-            }
-        }
+        None => match (args.which, args.quantized) {
+            (Which::V1Orig, true) => vec![repo.get("model-q4k.gguf")?],
+            (Which::V1 | Which::V1Zephyr | Which::V2 | Which::V2Zephyr | Which::Code, true) => {
+                anyhow::bail!("Quantized {:?} variant not supported.", args.which)
+            }
+            (Which::V1Orig | Which::V1 | Which::V1Zephyr | Which::V2 | Which::V2Zephyr, false) => {
+                vec![repo.get("model.safetensors")?]
+            }
+            (Which::Code, false) => {
+                candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")?
+            }
+        },
};

println!("retrieved the files in {:?}", start.elapsed());
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;

let start = std::time::Instant::now();
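    // Only V1Orig keeps the hardcoded config; every other variant
    // deserializes the repo's config.json and then applies the flash
    // attention flag on top.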
-    let config = Config::stablelm_3b_4e1t(args.use_flash_attn);
+    let config = match args.which {
+        Which::V1Orig => Config::stablelm_3b_4e1t(args.use_flash_attn),
+        Which::V1 | Which::V1Zephyr | Which::V2 | Which::V2Zephyr | Which::Code => {
+            let config_filename = repo.get("config.json")?;
+            let config = std::fs::read_to_string(config_filename)?;
+            let mut config: Config = serde_json::from_str(&config)?;
+            config.set_use_flash_attn(args.use_flash_attn);
+            config
+        }
+    };

let device = candle_examples::device(args.cpu)?;
let (model, device) = if args.quantized {
let filename = &filenames[0];
27 changes: 21 additions & 6 deletions candle-transformers/src/models/stable_lm.rs
@@ -1,10 +1,11 @@
-use crate::models::with_tracing::{linear_no_bias, Linear};
+use crate::models::with_tracing::{linear, linear_no_bias, Linear};
use candle::{DType, Device, Module, Result, Tensor, D};
use candle_nn::{Activation, LayerNorm, VarBuilder};
use serde::Deserialize;
use std::sync::Arc;

// https://huggingface.co/stabilityai/stablelm-3b-4e1t/blob/main/configuration_stablelm_epoch.py
-#[derive(Debug, Clone, PartialEq)]
+#[derive(Debug, Clone, PartialEq, Deserialize)]
pub struct Config {
pub(crate) vocab_size: usize,
pub(crate) intermediate_size: usize,
@@ -18,7 +19,10 @@ pub struct Config {
pub(crate) max_position_embeddings: usize,
pub(crate) norm_eps: f64,
pub(crate) use_cache: bool,
-    pub(crate) use_flash_attn: bool,
+    #[serde(default)]
+    pub(crate) use_qkv_bias: bool, // Used in StableLM-2
+    #[serde(default)]
+    pub(crate) use_flash_attn: bool, // Not in config.json
}
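
// The `#[serde(default)]` attributes above let config.json files that omit
// `use_qkv_bias` or `use_flash_attn` (e.g. V1-era configs) deserialize with
// both flags set to false.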

impl Config {
@@ -35,6 +39,7 @@ impl Config {
rope_theta: 10_000.,
max_position_embeddings: 4096,
norm_eps: 1e-5,
use_qkv_bias: false,
use_cache: true,
use_flash_attn,
}
@@ -51,6 +56,10 @@ impl Config {
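    // Grouped-query attention: this many query heads share each
    // key/value head.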
pub fn num_kv_groups(&self) -> usize {
self.num_attention_heads / self.num_key_value_heads
}

pub fn set_use_flash_attn(&mut self, use_flash_attn: bool) {
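        // `use_flash_attn` never appears in config.json (see the serde
        // default above), so callers set it here after deserializing.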
self.use_flash_attn = use_flash_attn
}
}

#[derive(Debug)]
@@ -179,9 +188,15 @@ impl Attention {
let head_dim = cfg.head_dim();
let num_heads = cfg.num_attention_heads;
let num_kv_heads = cfg.num_key_value_heads;
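        // `linear` and `linear_no_bias` share the same signature, so the
        // branch below can select one as a plain function value; per the
        // config comment, only StableLM-2 uses q/k/v biases.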
-        let q_proj = linear_no_bias(hidden_sz, num_heads * head_dim, vb.pp("q_proj"))?;
-        let k_proj = linear_no_bias(hidden_sz, num_kv_heads * head_dim, vb.pp("k_proj"))?;
-        let v_proj = linear_no_bias(hidden_sz, num_kv_heads * head_dim, vb.pp("v_proj"))?;
+        let linear_layer = if cfg.use_qkv_bias {
+            linear
+        } else {
+            linear_no_bias
+        };
+
+        let q_proj = linear_layer(hidden_sz, num_heads * head_dim, vb.pp("q_proj"))?;
+        let k_proj = linear_layer(hidden_sz, num_kv_heads * head_dim, vb.pp("k_proj"))?;
+        let v_proj = linear_layer(hidden_sz, num_kv_heads * head_dim, vb.pp("v_proj"))?;
let o_proj = linear_no_bias(num_heads * head_dim, hidden_sz, vb.pp("o_proj"))?;
Ok(Self {
q_proj,