diff --git a/Cargo.toml b/Cargo.toml index 867f42d..6c2c565 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "usls" -version = "0.0.16" +version = "0.0.17" edition = "2021" description = "A Rust library integrated with ONNXRuntime, providing a collection of ML models." repository = "https://github.com/jamjamjon/usls" diff --git a/examples/yolo/README.md b/examples/yolo/README.md index fce2d63..7ca9ec8 100644 --- a/examples/yolo/README.md +++ b/examples/yolo/README.md @@ -28,32 +28,34 @@ cargo run -r --example yolo -- --task detect --ver v8 --nc 6 --model xxx.onnx # YOLOv8 # Classify -cargo run -r --example yolo -- --task classify --ver v5 --scale n --width 224 --height 224 --nc 1000 # YOLOv5 +cargo run -r --example yolo -- --task classify --ver v5 --scale s --width 224 --height 224 --nc 1000 # YOLOv5 cargo run -r --example yolo -- --task classify --ver v8 --scale n --width 224 --height 224 --nc 1000 # YOLOv8 +cargo run -r --example yolo -- --task classify --ver v11 --scale n --width 224 --height 224 --nc 1000 # YOLOv11 # Detect -cargo run -r --example yolo -- --task detect --ver v5 --scale n --nc 80 # YOLOv5 -cargo run -r --example yolo -- --task detect --ver v6 --scale n --nc 80 # YOLOv6 -cargo run -r --example yolo -- --task detect --ver v7 --scale t --nc 80 # YOLOv7 -cargo run -r --example yolo -- --task detect --ver v8 --scale n --nc 80 # YOLOv8 -cargo run -r --example yolo -- --task detect --ver v9 --scale t --nc 80 # YOLOv9 -cargo run -r --example yolo -- --task detect --ver v10 --scale n --nc 80 # YOLOv10 -cargo run -r --example yolo -- --task detect --ver v11 --scale n --nc 80 # YOLOv11 -cargo run -r --example yolo -- --task detect --ver rtdetr --scale l --nc 80 # RTDETR -cargo run -r --example yolo -- --task detect --ver v8 --nc 1 --model yolov8s-world-v2-shoes.onnx # YOLOv8-world +cargo run -r --example yolo -- --task detect --ver v5 --scale n # YOLOv5 +cargo run -r --example yolo -- --task detect --ver v6 --scale n # YOLOv6 
+cargo run -r --example yolo -- --task detect --ver v7 --scale t # YOLOv7 +cargo run -r --example yolo -- --task detect --ver v8 --scale n # YOLOv8 +cargo run -r --example yolo -- --task detect --ver v9 --scale t # YOLOv9 +cargo run -r --example yolo -- --task detect --ver v10 --scale n # YOLOv10 +cargo run -r --example yolo -- --task detect --ver v11 --scale n # YOLOv11 +cargo run -r --example yolo -- --task detect --ver rtdetr --scale l # RTDETR +cargo run -r --example yolo -- --task detect --ver v8 --nc 1 --model yolov8s-world-v2-shoes.onnx # YOLOv8-world # Pose -cargo run -r --example yolo -- --task pose --ver v8 --scale n --nc 1 # YOLOv8-Pose -cargo run -r --example yolo -- --task pose --ver v11 --scale n --nc 1 # YOLOv11-Pose +cargo run -r --example yolo -- --task pose --ver v8 --scale n # YOLOv8-Pose +cargo run -r --example yolo -- --task pose --ver v11 --scale n # YOLOv11-Pose # Segment -cargo run -r --example yolo -- --task segment --ver v5 --scale n --nc 80 # YOLOv5-Segment -cargo run -r --example yolo -- --task segment --ver v8 --scale n --nc 80 # YOLOv8-Segment -cargo run -r --example yolo -- --task segment --ver v8 --model FastSAM-s-dyn-f16.onnx # FastSAM +cargo run -r --example yolo -- --task segment --ver v5 --scale n # YOLOv5-Segment +cargo run -r --example yolo -- --task segment --ver v8 --scale n # YOLOv8-Segment +cargo run -r --example yolo -- --task segment --ver v11 --scale n # YOLOv11-Segment +cargo run -r --example yolo -- --task segment --ver v8 --model FastSAM-s-dyn-f16.onnx # FastSAM # Obb -cargo run -r --example yolo -- --ver v8 --task obb --scale s --source images/dota.png # YOLOv8-Obb -cargo run -r --example yolo -- --ver v11 --task obb --scale s --source images/dota.png # YOLOv11-Obb +cargo run -r --example yolo -- --ver v8 --task obb --scale n --width 1024 --height 1024 --source images/dota.png # YOLOv8-Obb +cargo run -r --example yolo -- --ver v11 --task obb --scale n --width 1024 --height 1024 --source images/dota.png # YOLOv11-Obb 
``` **`cargo run -r --example yolo -- --help` for more options** @@ -68,6 +70,8 @@ cargo run -r --example yolo -- --ver v11 --task obb --scale s --source images/do let options = Options::default() .with_yolo_version(YOLOVersion::V5) // YOLOVersion: V5, V6, V7, V8, V9, V10, RTDETR .with_yolo_task(YOLOTask::Classify) // YOLOTask: Classify, Detect, Pose, Segment, Obb + // .with_nc(80) + // .with_names(&COCO_CLASS_NAMES_80) .with_model("xxxx.onnx")?; ``` diff --git a/examples/yolo/main.rs b/examples/yolo/main.rs index 8118fdc..5991a9b 100644 --- a/examples/yolo/main.rs +++ b/examples/yolo/main.rs @@ -3,7 +3,7 @@ use clap::Parser; use usls::{ models::YOLO, Annotator, DataLoader, Device, Options, Viewer, Vision, YOLOScale, YOLOTask, - YOLOVersion, COCO_KEYPOINTS_17, COCO_SKELETONS_16, + YOLOVersion, COCO_SKELETONS_16, }; #[derive(Parser, Clone)] @@ -42,7 +42,7 @@ pub struct Args { pub width: isize, /// Maximum input width - #[arg(long, default_value_t = 800)] + #[arg(long, default_value_t = 1024)] pub width_max: isize, /// Minimum input height @@ -54,7 +54,7 @@ pub struct Args { pub height: isize, /// Maximum input height - #[arg(long, default_value_t = 800)] + #[arg(long, default_value_t = 1024)] pub height_max: isize, /// Number of classes @@ -151,7 +151,7 @@ fn main() -> Result<()> { }) .with_nc(args.nc) // .with_names(&COCO_CLASS_NAMES_80) - .with_names2(&COCO_KEYPOINTS_17) + // .with_names2(&COCO_KEYPOINTS_17) .with_find_contours(!args.no_contours) // find contours or not .with_profile(args.profile); @@ -168,6 +168,7 @@ fn main() -> Result<()> { .with_skeletons(&COCO_SKELETONS_16) .without_masks(true) // No masks plotting when doing segment task. 
.with_bboxes_thickness(3) + .with_keypoints_name(false) // Enable keypoints names .with_saveout_subs(&["YOLO"]) .with_saveout(&saveout); diff --git a/src/models/yolo.rs b/src/models/yolo.rs index 931f244..546e35e 100644 --- a/src/models/yolo.rs +++ b/src/models/yolo.rs @@ -20,8 +20,8 @@ pub struct YOLO { confs: DynConf, kconfs: DynConf, iou: f32, - names: Option>, - names_kpt: Option>, + names: Vec, + names_kpt: Vec, task: YOLOTask, layout: YOLOPreds, find_contours: bool, @@ -96,42 +96,63 @@ impl Vision for YOLO { let task = task.unwrap_or(layout.task()); - // The number of classes & Class names - let mut names = options.names.or(Self::fetch_names(&engine)); - let nc = match options.nc { - Some(nc) => { - match &names { - None => names = Some((0..nc).map(|x| x.to_string()).collect::>()), - Some(names) => { - assert_eq!( - nc, + // Class names: user-defined.or(parsed) + let names_parsed = Self::fetch_names(&engine); + let names = match names_parsed { + Some(names_parsed) => match options.names { + Some(names) => { + if names.len() == names_parsed.len() { + Some(names) + } else { + anyhow::bail!( + "The lengths of parsed class names: {} and user-defined class names: {} do not match.", + names_parsed.len(), names.len(), - "The length of `nc` and `class names` is not equal." ); } } - nc - } - None => match &names { - Some(names) => names.len(), - None => panic!( - "Can not parse model without `nc` and `class names`. Try to make it explicit with `options.with_nc(80)`" - ), + None => Some(names_parsed), }, + None => options.names, }; - // Keypoints names - let names_kpt = options.names2; + // nc: names.len().or(options.nc) + let nc = match &names { + Some(names) => names.len(), + None => match options.nc { + Some(nc) => nc, + None => anyhow::bail!( + "Unable to obtain the number of classes. Please specify them explicitly using `options.with_nc(usize)` or `options.with_names(&[&str])`." 
+ ), + } + }; - // The number of keypoints - let nk = engine - .try_fetch("kpt_shape") - .map(|kpt_string| { - let re = Regex::new(r"([0-9]+), ([0-9]+)").unwrap(); - let caps = re.captures(&kpt_string).unwrap(); - caps.get(1).unwrap().as_str().parse::().unwrap() - }) - .unwrap_or(0_usize); + // Class names + let names = match names { + None => Self::n2s(nc), + Some(names) => names, + }; + + // Keypoint names & nk + let (nk, names_kpt) = match Self::fetch_kpts(&engine) { + None => (0, vec![]), + Some(nk) => match options.names2 { + Some(names) => { + if names.len() == nk { + (nk, names) + } else { + anyhow::bail!( + "The lengths of user-defined keypoint names: {} and nk: {} do not match.", + names.len(), + nk, + ); + } + } + None => (nk, Self::n2s(nk)), + }, + }; + + // Confs & Iou let confs = DynConf::new(&options.confs, nc); let kconfs = DynConf::new(&options.kconfs, nk); let iou = options.iou.unwrap_or(0.45); @@ -139,6 +160,7 @@ impl Vision for YOLO { // Summary tracing::info!("YOLO Task: {:?}, Version: {:?}", task, version); + // dry run engine.dry_run()?; Ok(Self { @@ -218,10 +240,8 @@ impl Vision for YOLO { slice_clss.into_owned() }; let mut probs = Prob::default().with_probs(&x.into_raw_vec_and_offset().0); - if let Some(names) = &self.names { - probs = - probs.with_names(&names.iter().map(|x| x.as_str()).collect::>()); - } + probs = probs + .with_names(&self.names.iter().map(|x| x.as_str()).collect::>()); return Some(y.with_probs(&probs)); } @@ -324,9 +344,7 @@ impl Vision for YOLO { ) .with_confidence(confidence) .with_id(class_id as isize); - if let Some(names) = &self.names { - mbr = mbr.with_name(&names[class_id]); - } + mbr = mbr.with_name(&self.names[class_id]); (None, Some(mbr)) } @@ -336,9 +354,7 @@ impl Vision for YOLO { .with_confidence(confidence) .with_id(class_id as isize) .with_id_born(i as isize); - if let Some(names) = &self.names { - bbox = bbox.with_name(&names[class_id]); - } + bbox = bbox.with_name(&self.names[class_id]); (Some(bbox), 
None) } @@ -393,9 +409,7 @@ impl Vision for YOLO { ky.max(0.0f32).min(image_height), ); - if let Some(names) = &self.names_kpt { - kpt = kpt.with_name(&names[i]); - } + kpt = kpt.with_name(&self.names_kpt[i]); kpt } }) @@ -504,16 +518,16 @@ impl Vision for YOLO { } impl YOLO { - pub fn batch(&self) -> isize { - self.batch.opt() as _ + pub fn batch(&self) -> usize { + self.batch.opt() } - pub fn width(&self) -> isize { - self.width.opt() as _ + pub fn width(&self) -> usize { + self.width.opt() } - pub fn height(&self) -> isize { - self.height.opt() as _ + pub fn height(&self) -> usize { + self.height.opt() } pub fn version(&self) -> Option<&YOLOVersion> { @@ -540,4 +554,16 @@ impl YOLO { names_ }) } + + fn fetch_kpts(engine: &OrtEngine) -> Option { + engine.try_fetch("kpt_shape").map(|s| { + let re = Regex::new(r"([0-9]+), ([0-9]+)").unwrap(); + let caps = re.captures(&s).unwrap(); + caps.get(1).unwrap().as_str().parse::().unwrap() + }) + } + + fn n2s(n: usize) -> Vec { + (0..n).map(|x| format!("# {}", x)).collect::>() + } }