diff --git a/candle-examples/examples/phi/main.rs b/candle-examples/examples/phi/main.rs index ea99c706bf..69eed84ff2 100644 --- a/candle-examples/examples/phi/main.rs +++ b/candle-examples/examples/phi/main.rs @@ -128,9 +128,8 @@ enum WhichModel { V1_5, #[value(name = "2")] V2, - // TODO: Make this the default once it has been battle tested. - #[value(name = "2-new")] - V2New, + #[value(name = "2-old")] + V2Old, PuffinPhiV2, PhiHermes, } @@ -236,7 +235,7 @@ fn main() -> Result<()> { match args.model { WhichModel::V1 => "microsoft/phi-1".to_string(), WhichModel::V1_5 => "microsoft/phi-1_5".to_string(), - WhichModel::V2 | WhichModel::V2New => "microsoft/phi-2".to_string(), + WhichModel::V2 | WhichModel::V2Old => "microsoft/phi-2".to_string(), WhichModel::PuffinPhiV2 | WhichModel::PhiHermes => { "lmz/candle-quantized-phi".to_string() } @@ -251,10 +250,10 @@ fn main() -> Result<()> { "main".to_string() } else { match args.model { - WhichModel::V1 => "refs/pr/2".to_string(), - WhichModel::V1_5 => "refs/pr/18".to_string(), - WhichModel::V2 => "834565c23f9b28b96ccbeabe614dd906b6db551a".to_string(), - WhichModel::V2New | WhichModel::PuffinPhiV2 | WhichModel::PhiHermes => { + WhichModel::V1 => "refs/pr/8".to_string(), + WhichModel::V1_5 => "refs/pr/73".to_string(), + WhichModel::V2Old => "834565c23f9b28b96ccbeabe614dd906b6db551a".to_string(), + WhichModel::V2 | WhichModel::PuffinPhiV2 | WhichModel::PhiHermes => { "main".to_string() } } @@ -265,7 +264,7 @@ fn main() -> Result<()> { let tokenizer_filename = match args.tokenizer { Some(file) => std::path::PathBuf::from(file), None => match args.model { - WhichModel::V1 | WhichModel::V1_5 | WhichModel::V2 | WhichModel::V2New => { + WhichModel::V1 | WhichModel::V1_5 | WhichModel::V2 | WhichModel::V2Old => { repo.get("tokenizer.json")? } WhichModel::PuffinPhiV2 | WhichModel::PhiHermes => { @@ -280,14 +279,14 @@ fn main() -> Result<()> { match args.model { WhichModel::V1 => vec![repo.get("model-v1-q4k.gguf")?], WhichModel::V1_5 => vec![repo.get("model-q4k.gguf")?], - WhichModel::V2 | WhichModel::V2New => vec![repo.get("model-v2-q4k.gguf")?], + WhichModel::V2 | WhichModel::V2Old => vec![repo.get("model-v2-q4k.gguf")?], WhichModel::PuffinPhiV2 => vec![repo.get("model-puffin-phi-v2-q4k.gguf")?], WhichModel::PhiHermes => vec![repo.get("model-phi-hermes-1_3B-q4k.gguf")?], } } else { match args.model { WhichModel::V1 | WhichModel::V1_5 => vec![repo.get("model.safetensors")?], - WhichModel::V2 | WhichModel::V2New => candle_examples::hub_load_safetensors( + WhichModel::V2 | WhichModel::V2Old => candle_examples::hub_load_safetensors( &repo, "model.safetensors.index.json", )?, @@ -304,35 +303,39 @@ fn main() -> Result<()> { let config = || match args.model { WhichModel::V1 => Config::v1(), WhichModel::V1_5 => Config::v1_5(), - WhichModel::V2 | WhichModel::V2New => Config::v2(), + WhichModel::V2 | WhichModel::V2Old => Config::v2(), WhichModel::PuffinPhiV2 => Config::puffin_phi_v2(), WhichModel::PhiHermes => Config::phi_hermes_1_3b(), }; - let (model, device) = if args.model == WhichModel::V2New { - let device = candle_examples::device(args.cpu)?; - let config_filename = repo.get("config.json")?; - let config = std::fs::read_to_string(config_filename)?; - let config: PhiConfig = serde_json::from_str(&config)?; - let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, DType::F32, &device)? }; - let phi = Phi::new(&config, vb)?; - (Model::Phi(phi), device) - } else if args.quantized { + let (model, device) = if args.quantized { let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf(&filenames[0])?; let config = config(); let model = match args.model { - WhichModel::V2 | WhichModel::V2New => QMixFormer::new_v2(&config, vb)?, + WhichModel::V2 | WhichModel::V2Old => QMixFormer::new_v2(&config, vb)?, _ => QMixFormer::new(&config, vb)?, }; (Model::Quantized(model), Device::Cpu) } else { let device = candle_examples::device(args.cpu)?; - let config = config(); let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, DType::F32, &device)? }; let model = match args.model { - WhichModel::V2 | WhichModel::V2New => MixFormer::new_v2(&config, vb)?, - _ => MixFormer::new(&config, vb)?, + WhichModel::V1 | WhichModel::V1_5 | WhichModel::V2 => { + let config_filename = repo.get("config.json")?; + let config = std::fs::read_to_string(config_filename)?; + let config: PhiConfig = serde_json::from_str(&config)?; + let phi = Phi::new(&config, vb)?; + Model::Phi(phi) + } + WhichModel::V2Old => { + let config = config(); + Model::MixFormer(MixFormer::new_v2(&config, vb)?) + } + WhichModel::PhiHermes | WhichModel::PuffinPhiV2 => { + let config = config(); + Model::MixFormer(MixFormer::new(&config, vb)?) + } }; - (Model::MixFormer(model), device) + (model, device) }; println!("loaded the model in {:?}", start.elapsed());