Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
  • Loading branch information
jamjamjon committed Sep 30, 2024
1 parent 9939faa commit cfeee00
Show file tree
Hide file tree
Showing 4 changed files with 102 additions and 71 deletions.
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "usls"
version = "0.0.16"
version = "0.0.17"
edition = "2021"
description = "A Rust library integrated with ONNXRuntime, providing a collection of ML models."
repository = "https://github.com/jamjamjon/usls"
Expand Down
38 changes: 21 additions & 17 deletions examples/yolo/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -28,32 +28,34 @@
cargo run -r --example yolo -- --task detect --ver v8 --nc 6 --model xxx.onnx # YOLOv8

# Classify
cargo run -r --example yolo -- --task classify --ver v5 --scale n --width 224 --height 224 --nc 1000 # YOLOv5
cargo run -r --example yolo -- --task classify --ver v5 --scale s --width 224 --height 224 --nc 1000 # YOLOv5
cargo run -r --example yolo -- --task classify --ver v8 --scale n --width 224 --height 224 --nc 1000 # YOLOv8
cargo run -r --example yolo -- --task classify --ver v11 --scale n --width 224 --height 224 --nc 1000 # YOLOv11

# Detect
cargo run -r --example yolo -- --task detect --ver v5 --scale n --nc 80 # YOLOv5
cargo run -r --example yolo -- --task detect --ver v6 --scale n --nc 80 # YOLOv6
cargo run -r --example yolo -- --task detect --ver v7 --scale t --nc 80 # YOLOv7
cargo run -r --example yolo -- --task detect --ver v8 --scale n --nc 80 # YOLOv8
cargo run -r --example yolo -- --task detect --ver v9 --scale t --nc 80 # YOLOv9
cargo run -r --example yolo -- --task detect --ver v10 --scale n --nc 80 # YOLOv10
cargo run -r --example yolo -- --task detect --ver v11 --scale n --nc 80 # YOLOv11
cargo run -r --example yolo -- --task detect --ver rtdetr --scale l --nc 80 # RTDETR
cargo run -r --example yolo -- --task detect --ver v8 --nc 1 --model yolov8s-world-v2-shoes.onnx # YOLOv8-world
cargo run -r --example yolo -- --task detect --ver v5 --scale n # YOLOv5
cargo run -r --example yolo -- --task detect --ver v6 --scale n # YOLOv6
cargo run -r --example yolo -- --task detect --ver v7 --scale t # YOLOv7
cargo run -r --example yolo -- --task detect --ver v8 --scale n # YOLOv8
cargo run -r --example yolo -- --task detect --ver v9 --scale t # YOLOv9
cargo run -r --example yolo -- --task detect --ver v10 --scale n # YOLOv10
cargo run -r --example yolo -- --task detect --ver v11 --scale n # YOLOv11
cargo run -r --example yolo -- --task detect --ver rtdetr --scale l # RTDETR
cargo run -r --example yolo -- --task detect --ver v8 --nc 1 --model yolov8s-world-v2-shoes.onnx # YOLOv8-world <local file>

# Pose
cargo run -r --example yolo -- --task pose --ver v8 --scale n --nc 1 # YOLOv8-Pose
cargo run -r --example yolo -- --task pose --ver v11 --scale n --nc 1 # YOLOv11-Pose
cargo run -r --example yolo -- --task pose --ver v8 --scale n # YOLOv8-Pose
cargo run -r --example yolo -- --task pose --ver v11 --scale n # YOLOv11-Pose

# Segment
cargo run -r --example yolo -- --task segment --ver v5 --scale n --nc 80 # YOLOv5-Segment
cargo run -r --example yolo -- --task segment --ver v8 --scale n --nc 80 # YOLOv8-Segment
cargo run -r --example yolo -- --task segment --ver v8 --model FastSAM-s-dyn-f16.onnx # FastSAM
cargo run -r --example yolo -- --task segment --ver v5 --scale n # YOLOv5-Segment
cargo run -r --example yolo -- --task segment --ver v8 --scale n # YOLOv8-Segment
cargo run -r --example yolo -- --task segment --ver v11 --scale n # YOLOv8-Segment
cargo run -r --example yolo -- --task segment --ver v8 --model FastSAM-s-dyn-f16.onnx # FastSAM <local file>

# Obb
cargo run -r --example yolo -- --ver v8 --task obb --scale s --source images/dota.png # YOLOv8-Obb
cargo run -r --example yolo -- --ver v11 --task obb --scale s --source images/dota.png # YOLOv11-Obb
cargo run -r --example yolo -- --ver v8 --task obb --scale n --width 1024 --height 1024 --source images/dota.png # YOLOv8-Obb
cargo run -r --example yolo -- --ver v11 --task obb --scale n --width 1024 --height 1024 --source images/dota.png # YOLOv11-Obb
```

**`cargo run -r --example yolo -- --help` for more options**
Expand All @@ -68,6 +70,8 @@ cargo run -r --example yolo -- --ver v11 --task obb --scale s --source images/do
let options = Options::default()
.with_yolo_version(YOLOVersion::V5) // YOLOVersion: V5, V6, V7, V8, V9, V10, RTDETR
.with_yolo_task(YOLOTask::Classify) // YOLOTask: Classify, Detect, Pose, Segment, Obb
// .with_nc(80)
// .with_names(&COCO_CLASS_NAMES_80)
.with_model("xxxx.onnx")?;

```
Expand Down
9 changes: 5 additions & 4 deletions examples/yolo/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use clap::Parser;

use usls::{
models::YOLO, Annotator, DataLoader, Device, Options, Viewer, Vision, YOLOScale, YOLOTask,
YOLOVersion, COCO_KEYPOINTS_17, COCO_SKELETONS_16,
YOLOVersion, COCO_SKELETONS_16,
};

#[derive(Parser, Clone)]
Expand Down Expand Up @@ -42,7 +42,7 @@ pub struct Args {
pub width: isize,

/// Maximum input width
#[arg(long, default_value_t = 800)]
#[arg(long, default_value_t = 1024)]
pub width_max: isize,

/// Minimum input height
Expand All @@ -54,7 +54,7 @@ pub struct Args {
pub height: isize,

/// Maximum input height
#[arg(long, default_value_t = 800)]
#[arg(long, default_value_t = 1024)]
pub height_max: isize,

/// Number of classes
Expand Down Expand Up @@ -151,7 +151,7 @@ fn main() -> Result<()> {
})
.with_nc(args.nc)
// .with_names(&COCO_CLASS_NAMES_80)
.with_names2(&COCO_KEYPOINTS_17)
// .with_names2(&COCO_KEYPOINTS_17)
.with_find_contours(!args.no_contours) // find contours or not
.with_profile(args.profile);

Expand All @@ -168,6 +168,7 @@ fn main() -> Result<()> {
.with_skeletons(&COCO_SKELETONS_16)
.without_masks(true) // No masks plotting when doing segment task.
.with_bboxes_thickness(3)
.with_keypoints_name(false) // Enable keypoints names
.with_saveout_subs(&["YOLO"])
.with_saveout(&saveout);

Expand Down
124 changes: 75 additions & 49 deletions src/models/yolo.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@ pub struct YOLO {
confs: DynConf,
kconfs: DynConf,
iou: f32,
names: Option<Vec<String>>,
names_kpt: Option<Vec<String>>,
names: Vec<String>,
names_kpt: Vec<String>,
task: YOLOTask,
layout: YOLOPreds,
find_contours: bool,
Expand Down Expand Up @@ -96,49 +96,71 @@ impl Vision for YOLO {

let task = task.unwrap_or(layout.task());

// The number of classes & Class names
let mut names = options.names.or(Self::fetch_names(&engine));
let nc = match options.nc {
Some(nc) => {
match &names {
None => names = Some((0..nc).map(|x| x.to_string()).collect::<Vec<String>>()),
Some(names) => {
assert_eq!(
nc,
// Class names: user-defined.or(parsed)
let names_parsed = Self::fetch_names(&engine);
let names = match names_parsed {
Some(names_parsed) => match options.names {
Some(names) => {
if names.len() == names_parsed.len() {
Some(names)
} else {
anyhow::bail!(
"The lengths of parsed class names: {} and user-defined class names: {} do not match.",
names_parsed.len(),
names.len(),
"The length of `nc` and `class names` is not equal."
);
}
}
nc
}
None => match &names {
Some(names) => names.len(),
None => panic!(
"Can not parse model without `nc` and `class names`. Try to make it explicit with `options.with_nc(80)`"
),
None => Some(names_parsed),
},
None => options.names,
};

// Keypoints names
let names_kpt = options.names2;
// nc: names.len().or(options.nc)
let nc = match &names {
Some(names) => names.len(),
None => match options.nc {
Some(nc) => nc,
None => anyhow::bail!(
"Unable to obtain the number of classes. Please specify them explicitly using `options.with_nc(usize)` or `options.with_names(&[&str])`."
),
}
};

// The number of keypoints
let nk = engine
.try_fetch("kpt_shape")
.map(|kpt_string| {
let re = Regex::new(r"([0-9]+), ([0-9]+)").unwrap();
let caps = re.captures(&kpt_string).unwrap();
caps.get(1).unwrap().as_str().parse::<usize>().unwrap()
})
.unwrap_or(0_usize);
// Class names
let names = match names {
None => Self::n2s(nc),
Some(names) => names,
};

// Keypoint names & nk
let (nk, names_kpt) = match Self::fetch_kpts(&engine) {
None => (0, vec![]),
Some(nk) => match options.names2 {
Some(names) => {
if names.len() == nk {
(nk, names)
} else {
anyhow::bail!(
"The lengths of user-defined keypoint names: {} and nk: {} do not match.",
names.len(),
nk,
);
}
}
None => (nk, Self::n2s(nk)),
},
};

// Confs & Iou
let confs = DynConf::new(&options.confs, nc);
let kconfs = DynConf::new(&options.kconfs, nk);
let iou = options.iou.unwrap_or(0.45);

// Summary
tracing::info!("YOLO Task: {:?}, Version: {:?}", task, version);

// dry run
engine.dry_run()?;

Ok(Self {
Expand Down Expand Up @@ -218,10 +240,8 @@ impl Vision for YOLO {
slice_clss.into_owned()
};
let mut probs = Prob::default().with_probs(&x.into_raw_vec_and_offset().0);
if let Some(names) = &self.names {
probs =
probs.with_names(&names.iter().map(|x| x.as_str()).collect::<Vec<_>>());
}
probs = probs
.with_names(&self.names.iter().map(|x| x.as_str()).collect::<Vec<_>>());

return Some(y.with_probs(&probs));
}
Expand Down Expand Up @@ -324,9 +344,7 @@ impl Vision for YOLO {
)
.with_confidence(confidence)
.with_id(class_id as isize);
if let Some(names) = &self.names {
mbr = mbr.with_name(&names[class_id]);
}
mbr = mbr.with_name(&self.names[class_id]);

(None, Some(mbr))
}
Expand All @@ -336,9 +354,7 @@ impl Vision for YOLO {
.with_confidence(confidence)
.with_id(class_id as isize)
.with_id_born(i as isize);
if let Some(names) = &self.names {
bbox = bbox.with_name(&names[class_id]);
}
bbox = bbox.with_name(&self.names[class_id]);

(Some(bbox), None)
}
Expand Down Expand Up @@ -393,9 +409,7 @@ impl Vision for YOLO {
ky.max(0.0f32).min(image_height),
);

if let Some(names) = &self.names_kpt {
kpt = kpt.with_name(&names[i]);
}
kpt = kpt.with_name(&self.names_kpt[i]);
kpt
}
})
Expand Down Expand Up @@ -504,16 +518,16 @@ impl Vision for YOLO {
}

impl YOLO {
pub fn batch(&self) -> isize {
self.batch.opt() as _
pub fn batch(&self) -> usize {
self.batch.opt()
}

pub fn width(&self) -> isize {
self.width.opt() as _
pub fn width(&self) -> usize {
self.width.opt()
}

pub fn height(&self) -> isize {
self.height.opt() as _
pub fn height(&self) -> usize {
self.height.opt()
}

pub fn version(&self) -> Option<&YOLOVersion> {
Expand All @@ -540,4 +554,16 @@ impl YOLO {
names_
})
}

fn fetch_kpts(engine: &OrtEngine) -> Option<usize> {
engine.try_fetch("kpt_shape").map(|s| {
let re = Regex::new(r"([0-9]+), ([0-9]+)").unwrap();
let caps = re.captures(&s).unwrap();
caps.get(1).unwrap().as_str().parse::<usize>().unwrap()
})
}

fn n2s(n: usize) -> Vec<String> {
(0..n).map(|x| format!("# {}", x)).collect::<Vec<String>>()
}
}

0 comments on commit cfeee00

Please sign in to comment.