
Commit

Fixes all clippy warnings
BradyBonnette committed Jan 26, 2025
1 parent cafad0d · commit c043e1c
Showing 1 changed file with 13 additions and 13 deletions.
candle-examples/examples/debertav2/main.rs (26 changes: 13 additions & 13 deletions)
@@ -20,14 +20,14 @@ use hf_hub::{api::sync::Api, Repo, RepoType};
 use tokenizers::{Encoding, PaddingParams, Tokenizer};
 
 enum TaskType {
-    NER(DebertaV2NERModel),
+    Ner(DebertaV2NERModel),
     TextClassification(DebertaV2SeqClassificationModel),
 }
 
 #[derive(Parser, Debug, Clone, ValueEnum)]
 enum ArgsTask {
     /// Named Entity Recognition
-    NER,
+    Ner,
 
     /// Text Classification
     TextClassification,
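
Note: the `NER` → `Ner` renames in this hunk (and at the matching `ArgsTask`/`TaskType` call sites below) are the kind of change Clippy's `upper_case_acronyms` lint asks for: fully capitalized names like `NER` get flagged, while mixed names such as `DebertaV2NERModel` are left alone under the default configuration. A minimal sketch of the lint, using hypothetical enum names rather than anything from this file:

    // Hypothetical sketch of clippy::upper_case_acronyms (not from main.rs).
    // `cargo clippy` flags the fully capitalized variant and suggests `Ner`.
    #[allow(dead_code)]
    enum TaskBefore {
        NER, // warning: name `NER` contains a capitalized acronym
    }

    // The accepted spelling, as applied throughout this diff.
    #[allow(dead_code)]
    enum TaskAfter {
        Ner,
    }

    fn main() {
        println!("clippy is happy with TaskAfter::Ner");
    }
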
@@ -36,7 +36,7 @@ enum ArgsTask {
 impl Display for ArgsTask {
     fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
         match self {
-            ArgsTask::NER => write!(f, "ner"),
+            ArgsTask::Ner => write!(f, "ner"),
             ArgsTask::TextClassification => write!(f, "text-classification"),
         }
     }
@@ -77,7 +77,7 @@ struct Args {
     benchmark_iters: Option<usize>,
 
     /// Which task to run
-    #[arg(long, default_value_t = ArgsTask::NER)]
+    #[arg(long, default_value_t = ArgsTask::Ner)]
     task: ArgsTask,
 
     /// Use model from a specific directory instead of HuggingFace local cache.
@@ -142,7 +142,7 @@ impl Args {
         // Command-line id2label takes precedence. Otherwise, use model config's id2label.
         // If neither is specified, then we can't proceed.
         let id2label = if let Some(id2labelstr) = &self.id2label {
-            serde_json::from_str(&&id2labelstr.as_str())?
+            serde_json::from_str(id2labelstr.as_str())?
         } else if let Some(id2label) = &config.id2label {
             id2label.clone()
         } else {
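
Note: dropping the double borrow here is presumably Clippy's `needless_borrow` lint: `id2labelstr.as_str()` already yields a `&str`, and the extra `&&` only wraps it in more reference layers that deref coercion has to unwrap before `serde_json::from_str` can use it. A small self-contained sketch of the same pattern (assumes a `serde_json` dependency; the JSON and the `HashMap<String, String>` target type are illustrative, not the types used in main.rs):

    use std::collections::HashMap;

    fn main() -> Result<(), serde_json::Error> {
        let raw = String::from(r#"{"0": "O", "1": "B-PER"}"#);

        // Before: serde_json::from_str(&&raw.as_str())?   <- needless_borrow
        // After:  pass the &str directly.
        let id2label: HashMap<String, String> = serde_json::from_str(raw.as_str())?;

        println!("{id2label:?}");
        Ok(())
    }
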
@@ -174,8 +174,8 @@ impl Args {
         let vb = vb.set_prefix("deberta");
 
         match self.task {
-            ArgsTask::NER => Ok((
-                TaskType::NER(DebertaV2NERModel::load(
+            ArgsTask::Ner => Ok((
+                TaskType::Ner(DebertaV2NERModel::load(
                     vb,
                     &config,
                     Some(id2label.clone()),
@@ -200,7 +200,7 @@
 
 fn get_device(model_type: &TaskType) -> &Device {
     match model_type {
-        TaskType::NER(ner_model) => &ner_model.device,
+        TaskType::Ner(ner_model) => &ner_model.device,
         TaskType::TextClassification(classification_model) => &classification_model.device,
     }
 }
@@ -253,9 +253,9 @@ fn main() -> Result<()> {
        let mut token_type_id_stack: Vec<Tensor> = Vec::default();
 
        for encoding in &tokenizer_encodings {
-            encoding_stack.push(Tensor::new(encoding.get_ids(), &device)?);
-            attention_mask_stack.push(Tensor::new(encoding.get_attention_mask(), &device)?);
-            token_type_id_stack.push(Tensor::new(encoding.get_type_ids(), &device)?);
+            encoding_stack.push(Tensor::new(encoding.get_ids(), device)?);
+            attention_mask_stack.push(Tensor::new(encoding.get_attention_mask(), device)?);
+            token_type_id_stack.push(Tensor::new(encoding.get_type_ids(), device)?);
        }
 
        ModelInput {
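
Note: the `&device` → `device` change in this hunk is the same `needless_borrow` story: here `device` is already a `&Device` (see `get_device` above), so `&device` hands `Tensor::new` a `&&Device`. A toy sketch of the borrow shape, with stand-in `Device`/`Tensor` types rather than candle's real API:

    // Stand-in types, only to show the reference level; not candle's API.
    struct Device;
    struct Tensor;

    impl Tensor {
        // Mirrors the shape of a constructor taking `&Device`.
        fn new(_data: &[u32], _device: &Device) -> Result<Tensor, String> {
            Ok(Tensor)
        }
    }

    fn main() -> Result<(), String> {
        let device: &Device = &Device; // like the &Device returned by get_device()
        let ids = [1u32, 2, 3];

        // Before: Tensor::new(&ids, &device)?   <- &device is &&Device (needless_borrow)
        let _tensor = Tensor::new(&ids, device)?;
        Ok(())
    }
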
@@ -272,7 +272,7 @@
     );
 
     match task_type {
-        TaskType::NER(ner_model) => {
+        TaskType::Ner(ner_model) => {
             if let Some(num_iters) = args.benchmark_iters {
                 create_benchmark(num_iters, model_input)(
                     |input_ids, token_type_ids, attention_mask| {
@@ -326,7 +326,7 @@
                     current_row_result.push(NERItem {
                         entity: label,
                         word: current_row_tokens[input_id_idx].clone(),
-                        score: current_row_max_scores[input_id_idx].clone(),
+                        score: current_row_max_scores[input_id_idx],
                         start: current_row_encoding.get_offsets()[input_id_idx].0,
                         end: current_row_encoding.get_offsets()[input_id_idx].1,
                         index: input_id_idx,
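
Note: the removed `.clone()` on the score is Clippy's `clone_on_copy` lint. The score is presumably an `f32`, which is `Copy`, so indexing the vector already hands back a copy and `.clone()` adds nothing. A minimal sketch (the `Item`/`score` names are illustrative):

    struct Item {
        score: f32,
    }

    fn main() {
        let max_scores = vec![0.97_f32, 0.42];

        // Before: score: max_scores[0].clone()   <- clone_on_copy
        let item = Item { score: max_scores[0] };

        println!("score = {}", item.score);
    }
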
