diff --git a/candle-examples/examples/debertav2/main.rs b/candle-examples/examples/debertav2/main.rs index 57742b8192..b332813abc 100644 --- a/candle-examples/examples/debertav2/main.rs +++ b/candle-examples/examples/debertav2/main.rs @@ -20,14 +20,14 @@ use hf_hub::{api::sync::Api, Repo, RepoType}; use tokenizers::{Encoding, PaddingParams, Tokenizer}; enum TaskType { - NER(DebertaV2NERModel), + Ner(DebertaV2NERModel), TextClassification(DebertaV2SeqClassificationModel), } #[derive(Parser, Debug, Clone, ValueEnum)] enum ArgsTask { /// Named Entity Recognition - NER, + Ner, /// Text Classification TextClassification, @@ -36,7 +36,7 @@ enum ArgsTask { impl Display for ArgsTask { fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { match self { - ArgsTask::NER => write!(f, "ner"), + ArgsTask::Ner => write!(f, "ner"), ArgsTask::TextClassification => write!(f, "text-classification"), } } @@ -77,7 +77,7 @@ struct Args { benchmark_iters: Option, /// Which task to run - #[arg(long, default_value_t = ArgsTask::NER)] + #[arg(long, default_value_t = ArgsTask::Ner)] task: ArgsTask, /// Use model from a specific directory instead of HuggingFace local cache. @@ -142,7 +142,7 @@ impl Args { // Command-line id2label takes precedence. Otherwise, use model config's id2label. // If neither is specified, then we can't proceed. let id2label = if let Some(id2labelstr) = &self.id2label { - serde_json::from_str(&&id2labelstr.as_str())? + serde_json::from_str(id2labelstr.as_str())? } else if let Some(id2label) = &config.id2label { id2label.clone() } else { @@ -174,8 +174,8 @@ impl Args { let vb = vb.set_prefix("deberta"); match self.task { - ArgsTask::NER => Ok(( - TaskType::NER(DebertaV2NERModel::load( + ArgsTask::Ner => Ok(( + TaskType::Ner(DebertaV2NERModel::load( vb, &config, Some(id2label.clone()), @@ -200,7 +200,7 @@ impl Args { fn get_device(model_type: &TaskType) -> &Device { match model_type { - TaskType::NER(ner_model) => &ner_model.device, + TaskType::Ner(ner_model) => &ner_model.device, TaskType::TextClassification(classification_model) => &classification_model.device, } } @@ -253,9 +253,9 @@ fn main() -> Result<()> { let mut token_type_id_stack: Vec = Vec::default(); for encoding in &tokenizer_encodings { - encoding_stack.push(Tensor::new(encoding.get_ids(), &device)?); - attention_mask_stack.push(Tensor::new(encoding.get_attention_mask(), &device)?); - token_type_id_stack.push(Tensor::new(encoding.get_type_ids(), &device)?); + encoding_stack.push(Tensor::new(encoding.get_ids(), device)?); + attention_mask_stack.push(Tensor::new(encoding.get_attention_mask(), device)?); + token_type_id_stack.push(Tensor::new(encoding.get_type_ids(), device)?); } ModelInput { @@ -272,7 +272,7 @@ fn main() -> Result<()> { ); match task_type { - TaskType::NER(ner_model) => { + TaskType::Ner(ner_model) => { if let Some(num_iters) = args.benchmark_iters { create_benchmark(num_iters, model_input)( |input_ids, token_type_ids, attention_mask| { @@ -326,7 +326,7 @@ fn main() -> Result<()> { current_row_result.push(NERItem { entity: label, word: current_row_tokens[input_id_idx].clone(), - score: current_row_max_scores[input_id_idx].clone(), + score: current_row_max_scores[input_id_idx], start: current_row_encoding.get_offsets()[input_id_idx].0, end: current_row_encoding.get_offsets()[input_id_idx].1, index: input_id_idx,