diff --git a/Cargo.toml b/Cargo.toml index 799bd06..5beb404 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "usls" -version = "0.0.14" +version = "0.0.15" edition = "2021" description = "A Rust library integrated with ONNXRuntime, providing a collection of ML models." repository = "https://github.com/jamjamjon/usls" @@ -22,7 +22,6 @@ dirs = { version = "5.0.1" } ureq = { version = "2.9.1", default-features = true, features = [ "socks-proxy", ] } -walkdir = { version = "2.5.0" } # TODO: remove tokenizers = { version = "0.15.2" } rayon = "1.10.0" indicatif = "0.17.8" diff --git a/README.md b/README.md index a3ba257..9090e04 100644 --- a/README.md +++ b/README.md @@ -37,7 +37,7 @@ - **YOLO Models**: [YOLOv5](https://github.com/ultralytics/yolov5), [YOLOv6](https://github.com/meituan/YOLOv6), [YOLOv7](https://github.com/WongKinYiu/yolov7), [YOLOv8](https://github.com/ultralytics/ultralytics), [YOLOv9](https://github.com/WongKinYiu/yolov9), [YOLOv10](https://github.com/THU-MIG/yolov10) - **SAM Models**: [SAM](https://github.com/facebookresearch/segment-anything), [SAM2](https://github.com/facebookresearch/segment-anything-2), [MobileSAM](https://github.com/ChaoningZhang/MobileSAM), [EdgeSAM](https://github.com/chongzhou96/EdgeSAM), [SAM-HQ](https://github.com/SysCV/sam-hq), [FastSAM](https://github.com/CASIA-IVA-Lab/FastSAM) - **Vision Models**: [RTDETR](https://arxiv.org/abs/2304.08069), [RTMO](https://github.com/open-mmlab/mmpose/tree/main/projects/rtmo), [DB](https://arxiv.org/abs/1911.08947), [SVTR](https://arxiv.org/abs/2205.00159), [Depth-Anything-v1-v2](https://github.com/LiheYoung/Depth-Anything), [DINOv2](https://github.com/facebookresearch/dinov2), [MODNet](https://github.com/ZHKKKe/MODNet), [Sapiens](https://arxiv.org/abs/2408.12569) -- **Vision-Language Models**: [CLIP](https://github.com/openai/CLIP), [BLIP](https://arxiv.org/abs/2201.12086), [GroundingDINO](https://github.com/IDEA-Research/GroundingDINO), 
[YOLO-World](https://github.com/AILab-CVC/YOLO-World) +- **Vision-Language Models**: [CLIP](https://github.com/openai/CLIP), [BLIP](https://arxiv.org/abs/2201.12086), [GroundingDINO](https://github.com/IDEA-Research/GroundingDINO), [YOLO-World](https://github.com/AILab-CVC/YOLO-World), [Florence2](https://arxiv.org/abs/2311.06242)
Click to expand Supported Models @@ -71,6 +71,9 @@ | [MODNet](https://github.com/ZHKKKe/MODNet) | Image Matting | [demo](examples/modnet) | ✅ | ✅ | ✅ | ✅ | | [GroundingDINO](https://github.com/IDEA-Research/GroundingDINO) | Open-Set Detection With Language | [demo](examples/grounding-dino) | ✅ | ✅ | | | | [Sapiens](https://github.com/facebookresearch/sapiens/tree/main) | Body Part Segmentation | [demo](examples/sapiens) | ✅ | ✅ | | | +| [Florence2](https://arxiv.org/abs/2311.06242) | a Variety of Vision Tasks | [demo](examples/florence2) | ✅ | ✅ | | | + +
diff --git a/examples/blip/main.rs b/examples/blip/main.rs index 656d5c2..b0cc929 100644 --- a/examples/blip/main.rs +++ b/examples/blip/main.rs @@ -10,7 +10,7 @@ fn main() -> Result<(), Box> { // textual let options_textual = Options::default() .with_model("blip/textual-base.onnx")? - // .with_tokenizer("blip/tokenizer.json")? + .with_tokenizer("blip/tokenizer.json")? .with_i00((1, 1, 4).into()) // input_id: batch .with_i01((1, 1, 4).into()) // input_id: seq_len .with_i10((1, 1, 4).into()) // attention_mask: batch diff --git a/examples/dataloader/main.rs b/examples/dataloader/main.rs index a0b62db..40484f2 100644 --- a/examples/dataloader/main.rs +++ b/examples/dataloader/main.rs @@ -18,26 +18,24 @@ fn main() -> anyhow::Result<()> { // build dataloader let dl = DataLoader::new( + // "images/bus.jpg", // remote image + // "../images", // image folder + // "../demo.mp4", // local video + // "http://commondatastorage.googleapis.com/gtv-videos-bucket/sample/BigBuckBunny.mp4", // remote video + // "rtsp://admin:xyz@192.168.2.217:554/h265/ch1/", // rtsp h264 stream "./assets/bus.jpg", // local image - // "images/bus.jpg", // remote image - // "../images", // image folder - // "../demo.mp4", // local video - // "http://commondatastorage.googleapis.com/gtv-videos-bucket/sample/BigBuckBunny.mp4", // remote video - // "rtsp://admin:xyz@192.168.2.217:554/h265/ch1/", // rtsp h264 stream )? 
.with_batch(1) - .with_progress_bar(true) - .with_bound(100) .build()?; - // // build annotator + // build annotator let annotator = Annotator::new() .with_bboxes_thickness(4) .with_saveout("YOLO-DataLoader"); // run for (xs, _) in dl { - // std::thread::sleep(std::time::Duration::from_millis(1000)); + // std::thread::sleep(std::time::Duration::from_millis(100)); let ys = model.forward(&xs, false)?; annotator.annotate(&xs, &ys); } diff --git a/examples/florence2/main.rs b/examples/florence2/main.rs new file mode 100644 index 0000000..546ebb9 --- /dev/null +++ b/examples/florence2/main.rs @@ -0,0 +1,252 @@ +use usls::{models::Florence2, Annotator, DataLoader, Options, Task}; + +fn main() -> Result<(), Box> { + // vision encoder + let options_vision_encoder = Options::default() + .with_model("florence2/base-vision-encoder.onnx")? + .with_i00((1, 2, 4).into()) + .with_i02((512, 768, 800).into()) + .with_i03((512, 768, 800).into()) + .with_profile(false) + .with_cuda(0); + + // text embed + let options_text_embed = Options::default() + .with_model("florence2/base-embed-tokens.onnx")? + .with_i00((1, 2, 4).into()) + .with_i01((1, 2, 20).into()) // seq_length + .with_tokenizer("florence2/tokenizer.json")? + .with_profile(false); + + // transformer encoder + let options_encoder = Options::default() + .with_model("florence2/base-encoder.onnx")? + .with_i00((1, 2, 4).into()) + .with_i01((1, 2, 300).into()) // encoder_sequence_length + .with_i10((1, 2, 4).into()) + .with_i11((1, 2, 300).into()) // encoder_sequence_length + .with_profile(false); + + // transformer decoder + let options_decoder = Options::default() + .with_model("florence2/base-decoder.onnx")? 
+ .with_i00((1, 2, 4).into()) + .with_i01((1, 2, 300).into()) // encoder_sequence_length + .with_i10((1, 2, 4).into()) + .with_i11((1, 2, 300).into()) // encoder_sequence_length + .with_i20((1, 2, 4).into()) + .with_i21((1, 2, 300).into()) // encoder_sequence_length + .with_profile(false); + + // transformer decoder merged + let options_decoder_merged = Options::default() + .with_model("florence2/base-decoder-merged.onnx")? + // encoder_attention_mask + .with_i00((1, 2, 4).into()) + .with_i01((1, 2, 300).into()) // encoder_sequence_length + // encoder_hidden_states + .with_i10((1, 2, 4).into()) + .with_i11((1, 2, 300).into()) // encoder_sequence_length + // inputs_embeds + .with_i20((1, 2, 4).into()) + .with_i21((1, 2, 300).into()) // encoder_sequence_length + // past_key_values.0.decoder.key + .with_i30((1, 2, 4).into()) + .with_i32_((1, 2, 1).into()) + // past_key_values.0.decoder.value + .with_i40((1, 2, 4).into()) + .with_i42((1, 2, 1).into()) + // past_key_values.0.encoder.key + .with_i50((1, 2, 4).into()) + .with_i52((1, 2, 1).into()) + // past_key_values.0.encoder.value + .with_i60((1, 2, 4).into()) + .with_i62((1, 2, 1).into()) + // past_key_values.1.decoder.key + .with_i70((1, 2, 4).into()) + .with_i72((1, 2, 1).into()) + // past_key_values.1.decoder.value + .with_i80((1, 2, 4).into()) + .with_i82((1, 2, 1).into()) + // past_key_values.1.encoder.key + .with_i90((1, 2, 4).into()) + .with_i92((1, 2, 1).into()) + // past_key_values.1.encoder.value + .with_i100((1, 2, 4).into()) + .with_i102((1, 2, 1).into()) + // past_key_values.2.decoder.key + .with_i110((1, 2, 4).into()) + .with_i112((1, 2, 1).into()) + // past_key_values.2.decoder.value + .with_i120((1, 2, 4).into()) + .with_i122((1, 2, 1).into()) + // past_key_values.2.encoder.key + .with_i130((1, 2, 4).into()) + .with_i132((1, 2, 1).into()) + // past_key_values.2.encoder.value + .with_i140((1, 2, 4).into()) + .with_i142((1, 2, 1).into()) + // past_key_values.3.decoder.key + .with_i150((1, 2, 4).into()) + 
.with_i152((1, 2, 1).into()) + // past_key_values.3.decoder.value + .with_i160((1, 2, 4).into()) + .with_i162((1, 2, 1).into()) + // past_key_values.3.encoder.key + .with_i170((1, 2, 4).into()) + .with_i172((1, 2, 1).into()) + // past_key_values.3.encoder.value + .with_i180((1, 2, 4).into()) + .with_i182((1, 2, 1).into()) + // past_key_values.4.decoder.key + .with_i190((1, 2, 4).into()) + .with_i192((1, 2, 1).into()) + // past_key_values.4.decoder.value + .with_i200((1, 2, 4).into()) + .with_i202((1, 2, 1).into()) + // past_key_values.4.encoder.key + .with_i210((1, 2, 4).into()) + .with_i212((1, 2, 1).into()) + // past_key_values.4.encoder.value + .with_i220((1, 2, 4).into()) + .with_i222((1, 2, 1).into()) + // past_key_values.5.decoder.key + .with_i230((1, 2, 4).into()) + .with_i232((1, 2, 1).into()) + // past_key_values.5.decoder.value + .with_i240((1, 2, 4).into()) + .with_i242((1, 2, 1).into()) + // past_key_values.5.encoder.key + .with_i250((1, 2, 4).into()) + .with_i252((1, 2, 1).into()) + // past_key_values.5.encoder.value + .with_i260((1, 2, 4).into()) + .with_i262((1, 2, 1).into()) + // use_cache_branch + .with_i270((1, 2, 1).into()) + .with_profile(false); + + // build model + let mut model = Florence2::new( + options_vision_encoder, + options_text_embed, + options_encoder, + options_decoder, + options_decoder_merged, + )?; + + // load images + let xs = [ + // DataLoader::try_read("florence2/car.jpg")?, // for testing region-related tasks + DataLoader::try_read("florence2/car.jpg")?, + // DataLoader::try_read("images/db.png")?, + DataLoader::try_read("assets/bus.jpg")?, + ]; + + // region-related tasks + let quantizer = usls::Quantizer::default(); + // let coords = [449., 270., 556., 372.]; // wheel + let coords = [31., 156., 581., 373.]; // car + let (width_car, height_car) = (xs[0].width(), xs[0].height()); + let quantized_coords = quantizer.quantize(&coords, (width_car as _, height_car as _)); + + // run with tasks + let ys = model.run_with_tasks( + 
&xs, + &[ + // w/ inputs + Task::Caption(0), + Task::Caption(1), + Task::Caption(2), + Task::Ocr, + Task::OcrWithRegion, + Task::RegionProposal, + Task::ObjectDetection, + Task::DenseRegionCaption, + // w/o inputs + Task::OpenSetDetection("a vehicle".into()), + Task::CaptionToPhraseGrounding( + "A vehicle with two wheels parked in front of a building.".into(), + ), + Task::ReferringExpressionSegmentation("a vehicle".into()), + Task::RegionToSegmentation( + quantized_coords[0], + quantized_coords[1], + quantized_coords[2], + quantized_coords[3], + ), + Task::RegionToCategory( + quantized_coords[0], + quantized_coords[1], + quantized_coords[2], + quantized_coords[3], + ), + Task::RegionToDescription( + quantized_coords[0], + quantized_coords[1], + quantized_coords[2], + quantized_coords[3], + ), + ], + )?; + + // annotator + let annotator = Annotator::new() + .without_bboxes_conf(true) + .with_bboxes_thickness(3) + .with_saveout_subs(&["Florence2"]); + for (task, ys_) in ys.iter() { + match task { + Task::Caption(_) + | Task::Ocr + | Task::RegionToCategory(..) + | Task::RegionToDescription(..) 
=> { + println!("Task: {:?}\n{:?}\n", task, ys_) + } + Task::DenseRegionCaption => { + let annotator = annotator.clone().with_saveout("Dense-Region-Caption"); + annotator.annotate(&xs, ys_); + } + Task::RegionProposal => { + let annotator = annotator + .clone() + .without_bboxes_name(false) + .with_saveout("Region-Proposal"); + + annotator.annotate(&xs, ys_); + } + Task::ObjectDetection => { + let annotator = annotator.clone().with_saveout("Object-Detection"); + annotator.annotate(&xs, ys_); + } + Task::OpenSetDetection(_) => { + let annotator = annotator.clone().with_saveout("Open-Set-Detection"); + annotator.annotate(&xs, ys_); + } + Task::CaptionToPhraseGrounding(_) => { + let annotator = annotator + .clone() + .with_saveout("Caption-To-Phrase-Grounding"); + annotator.annotate(&xs, ys_); + } + Task::ReferringExpressionSegmentation(_) => { + let annotator = annotator + .clone() + .with_saveout("Referring-Expression-Segmentation"); + annotator.annotate(&xs, ys_); + } + Task::RegionToSegmentation(..) 
=> { + let annotator = annotator.clone().with_saveout("Region-To-Segmentation"); + annotator.annotate(&xs, ys_); + } + Task::OcrWithRegion => { + let annotator = annotator.clone().with_saveout("Ocr-With-Region"); + annotator.annotate(&xs, ys_); + } + + _ => (), + } + } + + Ok(()) +} diff --git a/src/core/annotator.rs b/src/core/annotator.rs index 75a667c..af8c14d 100644 --- a/src/core/annotator.rs +++ b/src/core/annotator.rs @@ -2,19 +2,21 @@ use crate::{ colormap256, string_now, Bbox, Dir, Hub, Keypoint, Mask, Mbr, Polygon, Prob, CHECK_MARK, CROSS_MARK, Y, }; -use ab_glyph::{FontVec, PxScale}; +use ab_glyph::{FontArc, PxScale}; use anyhow::Result; use image::{DynamicImage, GenericImage, Rgba, RgbaImage}; use imageproc::map::map_colors; /// Annotator for struct `Y` -// #[derive(Debug)] +#[derive(Clone)] pub struct Annotator { - font: FontVec, + // TODO: Add lifetime + font: FontArc, _scale: f32, // Cope with ab_glyph & imageproc=0.24.0 scale_dy: f32, saveout_base: String, saveout: Option, + saveout_subs: Vec, decimal_places: usize, // About mbrs @@ -72,6 +74,7 @@ impl Default for Annotator { scale_dy: 28., polygons_alpha: 179, saveout: None, + saveout_subs: vec![], saveout_base: String::from("runs"), decimal_places: 4, without_bboxes: false, @@ -323,6 +326,11 @@ impl Annotator { self } + pub fn with_saveout_subs(mut self, xs: &[&str]) -> Self { + self.saveout_subs = xs.iter().map(|x| x.to_string()).collect::>(); + self + } + pub fn with_font(mut self, path: &str) -> Result { self.font = Self::load_font(Some(path))?; Ok(self) @@ -330,11 +338,23 @@ impl Annotator { /// Create folders for saving annotated results. 
e.g., `./runs/xxx` pub fn saveout(&self) -> Result { - let subs = match &self.saveout { - Some(x) => vec![self.saveout_base.as_str(), x.as_str()], - None => vec![self.saveout_base.as_str()], - }; + let mut subs = vec![self.saveout_base.as_str()]; + if let Some(saveout) = &self.saveout { + // add subs + if !self.saveout_subs.is_empty() { + let xs = self + .saveout_subs + .iter() + .map(|x| x.as_str()) + .collect::>(); + subs.extend(xs); + } + + // add filename + subs.push(saveout); + } + // mkdir even no filename specified Dir::Currnet.raw_path_with_subs(&subs) } @@ -345,7 +365,7 @@ impl Annotator { /// Plot images and return plotted images(RGBA8) pub fn plot(&self, imgs: &[DynamicImage], ys: &[Y]) -> Result> { - let span = tracing::span!(tracing::Level::INFO, "YOLO-new"); + let span = tracing::span!(tracing::Level::INFO, "Annotator-plot"); let _guard = span.enter(); let mut vs: Vec = Vec::new(); @@ -719,7 +739,7 @@ impl Annotator { let top = if y > text_h as f32 { (y.round() as u32 - text_h) as i32 } else { - (text_h - y.round() as u32) as i32 + 0 }; let mut left = x as i32; if left + text_w as i32 > img.width() as i32 { @@ -749,13 +769,13 @@ impl Annotator { } /// Load custom font - fn load_font(path: Option<&str>) -> Result { + fn load_font(path: Option<&str>) -> Result { let path_font = match path { None => Hub::new()?.fetch("fonts/Arial.ttf")?.commit()?, Some(p) => p.into(), }; let buffer = std::fs::read(path_font)?; - Ok(FontVec::try_from_vec(buffer.to_owned())?) + Ok(FontArc::try_from_vec(buffer.to_owned())?) 
} /// Pick color from pallette @@ -765,6 +785,7 @@ impl Annotator { /// Color pallette fn color_palette() -> [(u8, u8, u8, u8); 20] { + // TODO: more colors [ (0, 255, 127, 255), // spring green (255, 105, 180, 255), // hot pink diff --git a/src/core/dataloader.rs b/src/core/dataloader.rs index c4dfc51..502ce92 100644 --- a/src/core/dataloader.rs +++ b/src/core/dataloader.rs @@ -35,7 +35,7 @@ impl Iterator for DataLoaderIterator { None => { progress_bar.set_prefix(" Iterated"); progress_bar.set_style( - indicatif::ProgressStyle::with_template(crate::PROGRESS_BAR_STYLE_GREEN) + indicatif::ProgressStyle::with_template(crate::PROGRESS_BAR_STYLE_FINISH_2) .unwrap(), ); progress_bar.finish(); @@ -56,7 +56,7 @@ impl IntoIterator for DataLoader { self.nf / self.batch_size as u64, " Iterating", Some(&format!("{:?}", self.media_type)), - crate::PROGRESS_BAR_STYLE_CYAN, + crate::PROGRESS_BAR_STYLE_CYAN_2, ) .ok() } else { @@ -365,7 +365,12 @@ impl DataLoader { let saveout = Dir::Currnet .raw_path_with_subs(subs)? 
.join(format!("{}.mp4", string_now("-"))); - let pb = build_progress_bar(paths.len() as u64, " Converting", saveout.to_str(),"{prefix:.cyan.bold} {msg} |{bar}| ({percent_precise}%, {human_pos}/{human_len}, {per_sec})")?; + let pb = build_progress_bar( + paths.len() as u64, + " Converting", + Some(&format!("{:?}", MediaType::Video(Location::Local))), + crate::PROGRESS_BAR_STYLE_CYAN_2, + )?; // loop for path in paths { @@ -397,10 +402,10 @@ impl DataLoader { } // update - pb.set_prefix(" Downloaded"); pb.set_prefix(" Converted"); + pb.set_message(saveout.to_str().unwrap_or_default().to_string()); pb.set_style(ProgressStyle::with_template( - "{prefix:.green.bold} {msg} in {elapsed}", + crate::PROGRESS_BAR_STYLE_FINISH_4, )?); pb.finish(); diff --git a/src/core/hub.rs b/src/core/hub.rs index f7b49e5..c3f3e90 100644 --- a/src/core/hub.rs +++ b/src/core/hub.rs @@ -152,65 +152,59 @@ impl Hub { match p.exists() { true => self.path = p, false => { - // check local cache 1st - let p_cache = self.cache.with_file_name(s); - if p_cache.exists() { - self.path = p_cache; - } else { - // check remote list then - match s.split_once('/') { - Some((tag, file_name)) => { - // Extract tag and file from input string - self.tag = Some(tag.to_string()); - self.file_name = Some(file_name.to_string()); - - // Check if releases are already loaded in memory - if self.releases.is_none() { - self.releases = Some(self.connect_remote()?); + // check remote + match s.split_once('/') { + Some((tag, file_name)) => { + // Extract tag and file from input string + self.tag = Some(tag.to_string()); + self.file_name = Some(file_name.to_string()); + + // Check if releases are already loaded in memory + if self.releases.is_none() { + self.releases = Some(self.connect_remote()?); + } + + if let Some(releases) = &self.releases { + // Validate the tag + let tags: Vec<&str> = + releases.iter().map(|x| x.tag_name.as_str()).collect(); + if !tags.contains(&tag) { + anyhow::bail!( + "Hub tag '{}' not found in 
releases. Available tags: {:?}", + tag, + tags + ); } - if let Some(releases) = &self.releases { - // Validate the tag - let tags: Vec<&str> = - releases.iter().map(|x| x.tag_name.as_str()).collect(); - if !tags.contains(&tag) { + // Validate the file + if let Some(release) = releases.iter().find(|r| r.tag_name == tag) { + let files: Vec<&str> = + release.assets.iter().map(|x| x.name.as_str()).collect(); + if !files.contains(&file_name) { anyhow::bail!( - "Hub tag '{}' not found in releases. Available tags: {:?}", - tag, - tags - ); - } - - // Validate the file - if let Some(release) = releases.iter().find(|r| r.tag_name == tag) { - let files: Vec<&str> = - release.assets.iter().map(|x| x.name.as_str()).collect(); - if !files.contains(&file_name) { - anyhow::bail!( "Hub file '{}' not found in tag '{}'. Available files: {:?}", file_name, tag, files ); - } else { - for f_ in release.assets.iter() { - if f_.name.as_str() == file_name { - self.url = Some(f_.browser_download_url.clone()); - self.file_size = Some(f_.size); - - break; - } + } else { + for f_ in release.assets.iter() { + if f_.name.as_str() == file_name { + self.url = Some(f_.browser_download_url.clone()); + self.file_size = Some(f_.size); + + break; } } } - self.path = self.to.path_with_subs(&[tag])?.join(file_name); } + self.path = self.to.path_with_subs(&[tag])?.join(file_name); } - _ => anyhow::bail!( + } + _ => anyhow::bail!( "Download failed due to invalid format. 
Expected: /, got: {}", s ), - } } } } @@ -286,7 +280,7 @@ impl Hub { } pub fn connect_remote(&mut self) -> Result> { - let span = tracing::span!(tracing::Level::INFO, "OrtEngine-run"); + let span = tracing::span!(tracing::Level::INFO, "Hub-connect_remote"); let _guard = span.enter(); let should_download = if !self.cache.exists() { @@ -416,7 +410,7 @@ impl Hub { // update pb.set_prefix(" Downloaded"); pb.set_style(ProgressStyle::with_template( - "{prefix:.green.bold} {msg} ({binary_total_bytes}) in {elapsed}", + crate::PROGRESS_BAR_STYLE_FINISH_3, )?); pb.finish(); diff --git a/src/core/min_opt_max.rs b/src/core/min_opt_max.rs index 71c4970..803d4d2 100644 --- a/src/core/min_opt_max.rs +++ b/src/core/min_opt_max.rs @@ -50,4 +50,14 @@ impl MinOptMax { max: opt, } } + + pub fn update(&mut self, opt: isize) { + self.opt = opt; + if self.min > opt { + self.min = opt; + } + if self.max < opt { + self.max = opt; + } + } } diff --git a/src/core/ops.rs b/src/core/ops.rs index 10d3a74..7b60fb3 100644 --- a/src/core/ops.rs +++ b/src/core/ops.rs @@ -7,7 +7,7 @@ use fast_image_resize::{ FilterType, ResizeAlg, ResizeOptions, Resizer, }; use image::{DynamicImage, GenericImageView}; -use ndarray::{s, Array, Array3, Axis, IntoDimension, IxDyn}; +use ndarray::{concatenate, s, Array, Array3, Axis, IntoDimension, IxDyn}; use rayon::prelude::*; pub enum Ops<'a> { @@ -114,6 +114,14 @@ impl Ops<'_> { Self::permute(x, &[0, 2, 3, 1]) } + pub fn concatenate( + x: &Array, + y: &Array, + d: usize, + ) -> Result> { + Ok(concatenate(Axis(d), &[x.view(), y.view()])?) 
+ } + pub fn insert_axis(x: Array, d: usize) -> Result> { if x.shape().len() < d { anyhow::bail!( diff --git a/src/core/options.rs b/src/core/options.rs index 2c6cc41..3bd8691 100644 --- a/src/core/options.rs +++ b/src/core/options.rs @@ -4,13 +4,14 @@ use anyhow::Result; use crate::{ models::{SamKind, SapiensTask, YOLOPreds, YOLOTask, YOLOVersion}, - Device, Hub, MinOptMax, + Device, Hub, MinOptMax, Task, }; /// Options for building models #[derive(Debug, Clone)] pub struct Options { pub onnx_path: String, + pub task: Task, pub device: Device, pub profile: bool, pub num_dry_run: usize, @@ -62,6 +63,126 @@ pub struct Options { pub i73: Option, pub i74: Option, pub i75: Option, + pub i80: Option, + pub i81: Option, + pub i82: Option, + pub i83: Option, + pub i84: Option, + pub i85: Option, + pub i90: Option, + pub i91: Option, + pub i92: Option, + pub i93: Option, + pub i94: Option, + pub i95: Option, + pub i100: Option, + pub i101: Option, + pub i102: Option, + pub i103: Option, + pub i104: Option, + pub i105: Option, + pub i110: Option, + pub i111: Option, + pub i112: Option, + pub i113: Option, + pub i114: Option, + pub i115: Option, + pub i120: Option, + pub i121: Option, + pub i122: Option, + pub i123: Option, + pub i124: Option, + pub i125: Option, + pub i130: Option, + pub i131: Option, + pub i132: Option, + pub i133: Option, + pub i134: Option, + pub i135: Option, + pub i140: Option, + pub i141: Option, + pub i142: Option, + pub i143: Option, + pub i144: Option, + pub i145: Option, + pub i150: Option, + pub i151: Option, + pub i152: Option, + pub i153: Option, + pub i154: Option, + pub i155: Option, + pub i160: Option, + pub i161: Option, + pub i162: Option, + pub i163: Option, + pub i164: Option, + pub i165: Option, + pub i170: Option, + pub i171: Option, + pub i172: Option, + pub i173: Option, + pub i174: Option, + pub i175: Option, + pub i180: Option, + pub i181: Option, + pub i182: Option, + pub i183: Option, + pub i184: Option, + pub i185: Option, + pub 
i190: Option, + pub i191: Option, + pub i192: Option, + pub i193: Option, + pub i194: Option, + pub i195: Option, + pub i200: Option, + pub i201: Option, + pub i202: Option, + pub i203: Option, + pub i204: Option, + pub i205: Option, + pub i210: Option, + pub i211: Option, + pub i212: Option, + pub i213: Option, + pub i214: Option, + pub i215: Option, + pub i220: Option, + pub i221: Option, + pub i222: Option, + pub i223: Option, + pub i224: Option, + pub i225: Option, + pub i230: Option, + pub i231: Option, + pub i232: Option, + pub i233: Option, + pub i234: Option, + pub i235: Option, + pub i240: Option, + pub i241: Option, + pub i242: Option, + pub i243: Option, + pub i244: Option, + pub i245: Option, + pub i250: Option, + pub i251: Option, + pub i252: Option, + pub i253: Option, + pub i254: Option, + pub i255: Option, + pub i260: Option, + pub i261: Option, + pub i262: Option, + pub i263: Option, + pub i264: Option, + pub i265: Option, + pub i270: Option, + pub i271: Option, + pub i272: Option, + pub i273: Option, + pub i274: Option, + pub i275: Option, // trt related pub trt_engine_cache_enable: bool, pub trt_int8_enable: bool, @@ -149,6 +270,126 @@ impl Default for Options { i73: None, i74: None, i75: None, + i80: None, + i81: None, + i82: None, + i83: None, + i84: None, + i85: None, + i90: None, + i91: None, + i92: None, + i93: None, + i94: None, + i95: None, + i100: None, + i101: None, + i102: None, + i103: None, + i104: None, + i105: None, + i110: None, + i111: None, + i112: None, + i113: None, + i114: None, + i115: None, + i120: None, + i121: None, + i122: None, + i123: None, + i124: None, + i125: None, + i130: None, + i131: None, + i132: None, + i133: None, + i134: None, + i135: None, + i140: None, + i141: None, + i142: None, + i143: None, + i144: None, + i145: None, + i150: None, + i151: None, + i152: None, + i153: None, + i154: None, + i155: None, + i160: None, + i161: None, + i162: None, + i163: None, + i164: None, + i165: None, + i170: None, + i171: 
None, + i172: None, + i173: None, + i174: None, + i175: None, + i180: None, + i181: None, + i182: None, + i183: None, + i184: None, + i185: None, + i190: None, + i191: None, + i192: None, + i193: None, + i194: None, + i195: None, + i200: None, + i201: None, + i202: None, + i203: None, + i204: None, + i205: None, + i210: None, + i211: None, + i212: None, + i213: None, + i214: None, + i215: None, + i220: None, + i221: None, + i222: None, + i223: None, + i224: None, + i225: None, + i230: None, + i231: None, + i232: None, + i233: None, + i234: None, + i235: None, + i240: None, + i241: None, + i242: None, + i243: None, + i244: None, + i245: None, + i250: None, + i251: None, + i252: None, + i253: None, + i254: None, + i255: None, + i260: None, + i261: None, + i262: None, + i263: None, + i264: None, + i265: None, + i270: None, + i271: None, + i272: None, + i273: None, + i274: None, + i275: None, trt_engine_cache_enable: true, trt_int8_enable: false, trt_fp16_enable: false, @@ -176,6 +417,7 @@ impl Default for Options { sam_kind: None, use_low_res_mask: None, sapiens_task: None, + task: Task::Untitled, } } } @@ -185,6 +427,11 @@ impl Options { Default::default() } + pub fn with_task(mut self, task: Task) -> Self { + self.task = task; + self + } + pub fn with_model(mut self, onnx_path: &str) -> Result { self.onnx_path = Hub::new()?.fetch(onnx_path)?.commit()?; Ok(self) @@ -579,4 +826,603 @@ impl Options { self.i75 = Some(x); self } + + pub fn with_i80(mut self, x: MinOptMax) -> Self { + self.i80 = Some(x); + self + } + + pub fn with_i81(mut self, x: MinOptMax) -> Self { + self.i81 = Some(x); + self + } + + pub fn with_i82(mut self, x: MinOptMax) -> Self { + self.i82 = Some(x); + self + } + + pub fn with_i83(mut self, x: MinOptMax) -> Self { + self.i83 = Some(x); + self + } + + pub fn with_i84(mut self, x: MinOptMax) -> Self { + self.i84 = Some(x); + self + } + + pub fn with_i85(mut self, x: MinOptMax) -> Self { + self.i85 = Some(x); + self + } + + pub fn with_i90(mut self, 
x: MinOptMax) -> Self { + self.i90 = Some(x); + self + } + + pub fn with_i91(mut self, x: MinOptMax) -> Self { + self.i91 = Some(x); + self + } + + pub fn with_i92(mut self, x: MinOptMax) -> Self { + self.i92 = Some(x); + self + } + + pub fn with_i93(mut self, x: MinOptMax) -> Self { + self.i93 = Some(x); + self + } + + pub fn with_i94(mut self, x: MinOptMax) -> Self { + self.i94 = Some(x); + self + } + + pub fn with_i95(mut self, x: MinOptMax) -> Self { + self.i95 = Some(x); + self + } + + pub fn with_i100(mut self, x: MinOptMax) -> Self { + self.i100 = Some(x); + self + } + + pub fn with_i101(mut self, x: MinOptMax) -> Self { + self.i101 = Some(x); + self + } + + pub fn with_i102(mut self, x: MinOptMax) -> Self { + self.i102 = Some(x); + self + } + + pub fn with_i103(mut self, x: MinOptMax) -> Self { + self.i103 = Some(x); + self + } + + pub fn with_i104(mut self, x: MinOptMax) -> Self { + self.i104 = Some(x); + self + } + + pub fn with_i105(mut self, x: MinOptMax) -> Self { + self.i105 = Some(x); + self + } + + pub fn with_i110(mut self, x: MinOptMax) -> Self { + self.i110 = Some(x); + self + } + + pub fn with_i111(mut self, x: MinOptMax) -> Self { + self.i111 = Some(x); + self + } + + pub fn with_i112(mut self, x: MinOptMax) -> Self { + self.i112 = Some(x); + self + } + + pub fn with_i113(mut self, x: MinOptMax) -> Self { + self.i113 = Some(x); + self + } + + pub fn with_i114(mut self, x: MinOptMax) -> Self { + self.i114 = Some(x); + self + } + + pub fn with_i115(mut self, x: MinOptMax) -> Self { + self.i115 = Some(x); + self + } + + pub fn with_i120(mut self, x: MinOptMax) -> Self { + self.i120 = Some(x); + self + } + + pub fn with_i121(mut self, x: MinOptMax) -> Self { + self.i121 = Some(x); + self + } + + pub fn with_i122(mut self, x: MinOptMax) -> Self { + self.i122 = Some(x); + self + } + + pub fn with_i123(mut self, x: MinOptMax) -> Self { + self.i123 = Some(x); + self + } + + pub fn with_i124(mut self, x: MinOptMax) -> Self { + self.i124 = Some(x); + 
self + } + + pub fn with_i125(mut self, x: MinOptMax) -> Self { + self.i125 = Some(x); + self + } + + pub fn with_i130(mut self, x: MinOptMax) -> Self { + self.i130 = Some(x); + self + } + + pub fn with_i131(mut self, x: MinOptMax) -> Self { + self.i131 = Some(x); + self + } + + pub fn with_i132(mut self, x: MinOptMax) -> Self { + self.i132 = Some(x); + self + } + + pub fn with_i133(mut self, x: MinOptMax) -> Self { + self.i133 = Some(x); + self + } + + pub fn with_i134(mut self, x: MinOptMax) -> Self { + self.i134 = Some(x); + self + } + + pub fn with_i135(mut self, x: MinOptMax) -> Self { + self.i135 = Some(x); + self + } + + pub fn with_i140(mut self, x: MinOptMax) -> Self { + self.i140 = Some(x); + self + } + + pub fn with_i141(mut self, x: MinOptMax) -> Self { + self.i141 = Some(x); + self + } + + pub fn with_i142(mut self, x: MinOptMax) -> Self { + self.i142 = Some(x); + self + } + + pub fn with_i143(mut self, x: MinOptMax) -> Self { + self.i143 = Some(x); + self + } + + pub fn with_i144(mut self, x: MinOptMax) -> Self { + self.i144 = Some(x); + self + } + + pub fn with_i145(mut self, x: MinOptMax) -> Self { + self.i145 = Some(x); + self + } + + pub fn with_i150(mut self, x: MinOptMax) -> Self { + self.i150 = Some(x); + self + } + + pub fn with_i151(mut self, x: MinOptMax) -> Self { + self.i151 = Some(x); + self + } + + pub fn with_i152(mut self, x: MinOptMax) -> Self { + self.i152 = Some(x); + self + } + + pub fn with_i153(mut self, x: MinOptMax) -> Self { + self.i153 = Some(x); + self + } + + pub fn with_i154(mut self, x: MinOptMax) -> Self { + self.i154 = Some(x); + self + } + + pub fn with_i155(mut self, x: MinOptMax) -> Self { + self.i155 = Some(x); + self + } + + pub fn with_i160(mut self, x: MinOptMax) -> Self { + self.i160 = Some(x); + self + } + + pub fn with_i161(mut self, x: MinOptMax) -> Self { + self.i161 = Some(x); + self + } + + pub fn with_i162(mut self, x: MinOptMax) -> Self { + self.i162 = Some(x); + self + } + + pub fn with_i163(mut self, 
x: MinOptMax) -> Self { + self.i163 = Some(x); + self + } + + pub fn with_i164(mut self, x: MinOptMax) -> Self { + self.i164 = Some(x); + self + } + + pub fn with_i165(mut self, x: MinOptMax) -> Self { + self.i165 = Some(x); + self + } + + pub fn with_i170(mut self, x: MinOptMax) -> Self { + self.i170 = Some(x); + self + } + + pub fn with_i171(mut self, x: MinOptMax) -> Self { + self.i171 = Some(x); + self + } + + pub fn with_i172(mut self, x: MinOptMax) -> Self { + self.i172 = Some(x); + self + } + + pub fn with_i173(mut self, x: MinOptMax) -> Self { + self.i173 = Some(x); + self + } + + pub fn with_i174(mut self, x: MinOptMax) -> Self { + self.i174 = Some(x); + self + } + + pub fn with_i175(mut self, x: MinOptMax) -> Self { + self.i175 = Some(x); + self + } + + pub fn with_i180(mut self, x: MinOptMax) -> Self { + self.i180 = Some(x); + self + } + + pub fn with_i181(mut self, x: MinOptMax) -> Self { + self.i181 = Some(x); + self + } + + pub fn with_i182(mut self, x: MinOptMax) -> Self { + self.i182 = Some(x); + self + } + + pub fn with_i183(mut self, x: MinOptMax) -> Self { + self.i183 = Some(x); + self + } + + pub fn with_i184(mut self, x: MinOptMax) -> Self { + self.i184 = Some(x); + self + } + + pub fn with_i185(mut self, x: MinOptMax) -> Self { + self.i185 = Some(x); + self + } + + pub fn with_i190(mut self, x: MinOptMax) -> Self { + self.i190 = Some(x); + self + } + + pub fn with_i191(mut self, x: MinOptMax) -> Self { + self.i191 = Some(x); + self + } + + pub fn with_i192(mut self, x: MinOptMax) -> Self { + self.i192 = Some(x); + self + } + + pub fn with_i193(mut self, x: MinOptMax) -> Self { + self.i193 = Some(x); + self + } + + pub fn with_i194(mut self, x: MinOptMax) -> Self { + self.i194 = Some(x); + self + } + + pub fn with_i195(mut self, x: MinOptMax) -> Self { + self.i195 = Some(x); + self + } + + pub fn with_i200(mut self, x: MinOptMax) -> Self { + self.i200 = Some(x); + self + } + + pub fn with_i201(mut self, x: MinOptMax) -> Self { + self.i201 = 
Some(x); + self + } + + pub fn with_i202(mut self, x: MinOptMax) -> Self { + self.i202 = Some(x); + self + } + + pub fn with_i203(mut self, x: MinOptMax) -> Self { + self.i203 = Some(x); + self + } + + pub fn with_i204(mut self, x: MinOptMax) -> Self { + self.i204 = Some(x); + self + } + + pub fn with_i205(mut self, x: MinOptMax) -> Self { + self.i205 = Some(x); + self + } + + pub fn with_i210(mut self, x: MinOptMax) -> Self { + self.i210 = Some(x); + self + } + + pub fn with_i211(mut self, x: MinOptMax) -> Self { + self.i211 = Some(x); + self + } + + pub fn with_i212(mut self, x: MinOptMax) -> Self { + self.i212 = Some(x); + self + } + + pub fn with_i213(mut self, x: MinOptMax) -> Self { + self.i213 = Some(x); + self + } + + pub fn with_i214(mut self, x: MinOptMax) -> Self { + self.i214 = Some(x); + self + } + + pub fn with_i215(mut self, x: MinOptMax) -> Self { + self.i215 = Some(x); + self + } + + pub fn with_i220(mut self, x: MinOptMax) -> Self { + self.i220 = Some(x); + self + } + + pub fn with_i221(mut self, x: MinOptMax) -> Self { + self.i221 = Some(x); + self + } + + pub fn with_i222(mut self, x: MinOptMax) -> Self { + self.i222 = Some(x); + self + } + + pub fn with_i223(mut self, x: MinOptMax) -> Self { + self.i223 = Some(x); + self + } + + pub fn with_i224(mut self, x: MinOptMax) -> Self { + self.i224 = Some(x); + self + } + + pub fn with_i225(mut self, x: MinOptMax) -> Self { + self.i225 = Some(x); + self + } + + pub fn with_i230(mut self, x: MinOptMax) -> Self { + self.i230 = Some(x); + self + } + + pub fn with_i231(mut self, x: MinOptMax) -> Self { + self.i231 = Some(x); + self + } + + pub fn with_i232(mut self, x: MinOptMax) -> Self { + self.i232 = Some(x); + self + } + + pub fn with_i233(mut self, x: MinOptMax) -> Self { + self.i233 = Some(x); + self + } + + pub fn with_i234(mut self, x: MinOptMax) -> Self { + self.i234 = Some(x); + self + } + + pub fn with_i235(mut self, x: MinOptMax) -> Self { + self.i235 = Some(x); + self + } + + pub fn 
with_i240(mut self, x: MinOptMax) -> Self { + self.i240 = Some(x); + self + } + + pub fn with_i241(mut self, x: MinOptMax) -> Self { + self.i241 = Some(x); + self + } + + pub fn with_i242(mut self, x: MinOptMax) -> Self { + self.i242 = Some(x); + self + } + + pub fn with_i243(mut self, x: MinOptMax) -> Self { + self.i243 = Some(x); + self + } + + pub fn with_i244(mut self, x: MinOptMax) -> Self { + self.i244 = Some(x); + self + } + + pub fn with_i245(mut self, x: MinOptMax) -> Self { + self.i245 = Some(x); + self + } + + pub fn with_i250(mut self, x: MinOptMax) -> Self { + self.i250 = Some(x); + self + } + + pub fn with_i251(mut self, x: MinOptMax) -> Self { + self.i251 = Some(x); + self + } + + pub fn with_i252(mut self, x: MinOptMax) -> Self { + self.i252 = Some(x); + self + } + + pub fn with_i253(mut self, x: MinOptMax) -> Self { + self.i253 = Some(x); + self + } + + pub fn with_i254(mut self, x: MinOptMax) -> Self { + self.i254 = Some(x); + self + } + + pub fn with_i255(mut self, x: MinOptMax) -> Self { + self.i255 = Some(x); + self + } + pub fn with_i260(mut self, x: MinOptMax) -> Self { + self.i260 = Some(x); + self + } + + pub fn with_i261(mut self, x: MinOptMax) -> Self { + self.i261 = Some(x); + self + } + + pub fn with_i262(mut self, x: MinOptMax) -> Self { + self.i262 = Some(x); + self + } + + pub fn with_i263(mut self, x: MinOptMax) -> Self { + self.i263 = Some(x); + self + } + + pub fn with_i264(mut self, x: MinOptMax) -> Self { + self.i264 = Some(x); + self + } + + pub fn with_i265(mut self, x: MinOptMax) -> Self { + self.i265 = Some(x); + self + } + + pub fn with_i270(mut self, x: MinOptMax) -> Self { + self.i270 = Some(x); + self + } + + pub fn with_i271(mut self, x: MinOptMax) -> Self { + self.i271 = Some(x); + self + } + + pub fn with_i272(mut self, x: MinOptMax) -> Self { + self.i272 = Some(x); + self + } + + pub fn with_i273(mut self, x: MinOptMax) -> Self { + self.i273 = Some(x); + self + } + + pub fn with_i274(mut self, x: MinOptMax) -> Self { 
+ self.i274 = Some(x); + self + } + + pub fn with_i275(mut self, x: MinOptMax) -> Self { + self.i275 = Some(x); + self + } } diff --git a/src/core/ort_engine.rs b/src/core/ort_engine.rs index dd2994c..255a0ae 100644 --- a/src/core/ort_engine.rs +++ b/src/core/ort_engine.rs @@ -23,6 +23,7 @@ pub struct OrtTensorAttr { /// ONNXRuntime Backend #[derive(Debug)] pub struct OrtEngine { + name: String, session: Session, device: Device, inputs_minoptmax: Vec>, @@ -129,6 +130,126 @@ impl OrtEngine { (7, 3) => Self::_set_ixx(x, &config.i73, i, ii).unwrap_or(x_default), (7, 4) => Self::_set_ixx(x, &config.i74, i, ii).unwrap_or(x_default), (7, 5) => Self::_set_ixx(x, &config.i75, i, ii).unwrap_or(x_default), + (8, 0) => Self::_set_ixx(x, &config.i80, i, ii).unwrap_or(x_default), + (8, 1) => Self::_set_ixx(x, &config.i81, i, ii).unwrap_or(x_default), + (8, 2) => Self::_set_ixx(x, &config.i82, i, ii).unwrap_or(x_default), + (8, 3) => Self::_set_ixx(x, &config.i83, i, ii).unwrap_or(x_default), + (8, 4) => Self::_set_ixx(x, &config.i84, i, ii).unwrap_or(x_default), + (8, 5) => Self::_set_ixx(x, &config.i85, i, ii).unwrap_or(x_default), + (9, 0) => Self::_set_ixx(x, &config.i90, i, ii).unwrap_or(x_default), + (9, 1) => Self::_set_ixx(x, &config.i91, i, ii).unwrap_or(x_default), + (9, 2) => Self::_set_ixx(x, &config.i92, i, ii).unwrap_or(x_default), + (9, 3) => Self::_set_ixx(x, &config.i93, i, ii).unwrap_or(x_default), + (9, 4) => Self::_set_ixx(x, &config.i94, i, ii).unwrap_or(x_default), + (9, 5) => Self::_set_ixx(x, &config.i95, i, ii).unwrap_or(x_default), + (10, 0) => Self::_set_ixx(x, &config.i100, i, ii).unwrap_or(x_default), + (10, 1) => Self::_set_ixx(x, &config.i101, i, ii).unwrap_or(x_default), + (10, 2) => Self::_set_ixx(x, &config.i102, i, ii).unwrap_or(x_default), + (10, 3) => Self::_set_ixx(x, &config.i103, i, ii).unwrap_or(x_default), + (10, 4) => Self::_set_ixx(x, &config.i104, i, ii).unwrap_or(x_default), + (10, 5) => Self::_set_ixx(x, &config.i105, i, 
ii).unwrap_or(x_default), + (11, 0) => Self::_set_ixx(x, &config.i110, i, ii).unwrap_or(x_default), + (11, 1) => Self::_set_ixx(x, &config.i111, i, ii).unwrap_or(x_default), + (11, 2) => Self::_set_ixx(x, &config.i112, i, ii).unwrap_or(x_default), + (11, 3) => Self::_set_ixx(x, &config.i113, i, ii).unwrap_or(x_default), + (11, 4) => Self::_set_ixx(x, &config.i114, i, ii).unwrap_or(x_default), + (11, 5) => Self::_set_ixx(x, &config.i115, i, ii).unwrap_or(x_default), + (12, 0) => Self::_set_ixx(x, &config.i120, i, ii).unwrap_or(x_default), + (12, 1) => Self::_set_ixx(x, &config.i121, i, ii).unwrap_or(x_default), + (12, 2) => Self::_set_ixx(x, &config.i122, i, ii).unwrap_or(x_default), + (12, 3) => Self::_set_ixx(x, &config.i123, i, ii).unwrap_or(x_default), + (12, 4) => Self::_set_ixx(x, &config.i124, i, ii).unwrap_or(x_default), + (12, 5) => Self::_set_ixx(x, &config.i125, i, ii).unwrap_or(x_default), + (13, 0) => Self::_set_ixx(x, &config.i130, i, ii).unwrap_or(x_default), + (13, 1) => Self::_set_ixx(x, &config.i131, i, ii).unwrap_or(x_default), + (13, 2) => Self::_set_ixx(x, &config.i132, i, ii).unwrap_or(x_default), + (13, 3) => Self::_set_ixx(x, &config.i133, i, ii).unwrap_or(x_default), + (13, 4) => Self::_set_ixx(x, &config.i134, i, ii).unwrap_or(x_default), + (13, 5) => Self::_set_ixx(x, &config.i135, i, ii).unwrap_or(x_default), + (14, 0) => Self::_set_ixx(x, &config.i140, i, ii).unwrap_or(x_default), + (14, 1) => Self::_set_ixx(x, &config.i141, i, ii).unwrap_or(x_default), + (14, 2) => Self::_set_ixx(x, &config.i142, i, ii).unwrap_or(x_default), + (14, 3) => Self::_set_ixx(x, &config.i143, i, ii).unwrap_or(x_default), + (14, 4) => Self::_set_ixx(x, &config.i144, i, ii).unwrap_or(x_default), + (14, 5) => Self::_set_ixx(x, &config.i145, i, ii).unwrap_or(x_default), + (15, 0) => Self::_set_ixx(x, &config.i150, i, ii).unwrap_or(x_default), + (15, 1) => Self::_set_ixx(x, &config.i151, i, ii).unwrap_or(x_default), + (15, 2) => Self::_set_ixx(x, &config.i152, i, 
ii).unwrap_or(x_default), + (15, 3) => Self::_set_ixx(x, &config.i153, i, ii).unwrap_or(x_default), + (15, 4) => Self::_set_ixx(x, &config.i154, i, ii).unwrap_or(x_default), + (15, 5) => Self::_set_ixx(x, &config.i155, i, ii).unwrap_or(x_default), + (16, 0) => Self::_set_ixx(x, &config.i160, i, ii).unwrap_or(x_default), + (16, 1) => Self::_set_ixx(x, &config.i161, i, ii).unwrap_or(x_default), + (16, 2) => Self::_set_ixx(x, &config.i162, i, ii).unwrap_or(x_default), + (16, 3) => Self::_set_ixx(x, &config.i163, i, ii).unwrap_or(x_default), + (16, 4) => Self::_set_ixx(x, &config.i164, i, ii).unwrap_or(x_default), + (16, 5) => Self::_set_ixx(x, &config.i165, i, ii).unwrap_or(x_default), + (17, 0) => Self::_set_ixx(x, &config.i170, i, ii).unwrap_or(x_default), + (17, 1) => Self::_set_ixx(x, &config.i171, i, ii).unwrap_or(x_default), + (17, 2) => Self::_set_ixx(x, &config.i172, i, ii).unwrap_or(x_default), + (17, 3) => Self::_set_ixx(x, &config.i173, i, ii).unwrap_or(x_default), + (17, 4) => Self::_set_ixx(x, &config.i174, i, ii).unwrap_or(x_default), + (17, 5) => Self::_set_ixx(x, &config.i175, i, ii).unwrap_or(x_default), + (18, 0) => Self::_set_ixx(x, &config.i180, i, ii).unwrap_or(x_default), + (18, 1) => Self::_set_ixx(x, &config.i181, i, ii).unwrap_or(x_default), + (18, 2) => Self::_set_ixx(x, &config.i182, i, ii).unwrap_or(x_default), + (18, 3) => Self::_set_ixx(x, &config.i183, i, ii).unwrap_or(x_default), + (18, 4) => Self::_set_ixx(x, &config.i184, i, ii).unwrap_or(x_default), + (18, 5) => Self::_set_ixx(x, &config.i185, i, ii).unwrap_or(x_default), + (19, 0) => Self::_set_ixx(x, &config.i190, i, ii).unwrap_or(x_default), + (19, 1) => Self::_set_ixx(x, &config.i191, i, ii).unwrap_or(x_default), + (19, 2) => Self::_set_ixx(x, &config.i192, i, ii).unwrap_or(x_default), + (19, 3) => Self::_set_ixx(x, &config.i193, i, ii).unwrap_or(x_default), + (19, 4) => Self::_set_ixx(x, &config.i194, i, ii).unwrap_or(x_default), + (19, 5) => Self::_set_ixx(x, &config.i195, i, 
ii).unwrap_or(x_default), + (20, 0) => Self::_set_ixx(x, &config.i200, i, ii).unwrap_or(x_default), + (20, 1) => Self::_set_ixx(x, &config.i201, i, ii).unwrap_or(x_default), + (20, 2) => Self::_set_ixx(x, &config.i202, i, ii).unwrap_or(x_default), + (20, 3) => Self::_set_ixx(x, &config.i203, i, ii).unwrap_or(x_default), + (20, 4) => Self::_set_ixx(x, &config.i204, i, ii).unwrap_or(x_default), + (20, 5) => Self::_set_ixx(x, &config.i205, i, ii).unwrap_or(x_default), + (21, 0) => Self::_set_ixx(x, &config.i210, i, ii).unwrap_or(x_default), + (21, 1) => Self::_set_ixx(x, &config.i211, i, ii).unwrap_or(x_default), + (21, 2) => Self::_set_ixx(x, &config.i212, i, ii).unwrap_or(x_default), + (21, 3) => Self::_set_ixx(x, &config.i213, i, ii).unwrap_or(x_default), + (21, 4) => Self::_set_ixx(x, &config.i214, i, ii).unwrap_or(x_default), + (21, 5) => Self::_set_ixx(x, &config.i215, i, ii).unwrap_or(x_default), + (22, 0) => Self::_set_ixx(x, &config.i220, i, ii).unwrap_or(x_default), + (22, 1) => Self::_set_ixx(x, &config.i221, i, ii).unwrap_or(x_default), + (22, 2) => Self::_set_ixx(x, &config.i222, i, ii).unwrap_or(x_default), + (22, 3) => Self::_set_ixx(x, &config.i223, i, ii).unwrap_or(x_default), + (22, 4) => Self::_set_ixx(x, &config.i224, i, ii).unwrap_or(x_default), + (22, 5) => Self::_set_ixx(x, &config.i225, i, ii).unwrap_or(x_default), + (23, 0) => Self::_set_ixx(x, &config.i230, i, ii).unwrap_or(x_default), + (23, 1) => Self::_set_ixx(x, &config.i231, i, ii).unwrap_or(x_default), + (23, 2) => Self::_set_ixx(x, &config.i232, i, ii).unwrap_or(x_default), + (23, 3) => Self::_set_ixx(x, &config.i233, i, ii).unwrap_or(x_default), + (23, 4) => Self::_set_ixx(x, &config.i234, i, ii).unwrap_or(x_default), + (23, 5) => Self::_set_ixx(x, &config.i235, i, ii).unwrap_or(x_default), + (24, 0) => Self::_set_ixx(x, &config.i240, i, ii).unwrap_or(x_default), + (24, 1) => Self::_set_ixx(x, &config.i241, i, ii).unwrap_or(x_default), + (24, 2) => Self::_set_ixx(x, &config.i242, i, 
ii).unwrap_or(x_default), + (24, 3) => Self::_set_ixx(x, &config.i243, i, ii).unwrap_or(x_default), + (24, 4) => Self::_set_ixx(x, &config.i244, i, ii).unwrap_or(x_default), + (24, 5) => Self::_set_ixx(x, &config.i245, i, ii).unwrap_or(x_default), + (25, 0) => Self::_set_ixx(x, &config.i250, i, ii).unwrap_or(x_default), + (25, 1) => Self::_set_ixx(x, &config.i251, i, ii).unwrap_or(x_default), + (25, 2) => Self::_set_ixx(x, &config.i252, i, ii).unwrap_or(x_default), + (25, 3) => Self::_set_ixx(x, &config.i253, i, ii).unwrap_or(x_default), + (25, 4) => Self::_set_ixx(x, &config.i254, i, ii).unwrap_or(x_default), + (25, 5) => Self::_set_ixx(x, &config.i255, i, ii).unwrap_or(x_default), + (26, 0) => Self::_set_ixx(x, &config.i260, i, ii).unwrap_or(x_default), + (26, 1) => Self::_set_ixx(x, &config.i261, i, ii).unwrap_or(x_default), + (26, 2) => Self::_set_ixx(x, &config.i262, i, ii).unwrap_or(x_default), + (26, 3) => Self::_set_ixx(x, &config.i263, i, ii).unwrap_or(x_default), + (26, 4) => Self::_set_ixx(x, &config.i264, i, ii).unwrap_or(x_default), + (26, 5) => Self::_set_ixx(x, &config.i265, i, ii).unwrap_or(x_default), + (27, 0) => Self::_set_ixx(x, &config.i270, i, ii).unwrap_or(x_default), + (27, 1) => Self::_set_ixx(x, &config.i271, i, ii).unwrap_or(x_default), + (27, 2) => Self::_set_ixx(x, &config.i272, i, ii).unwrap_or(x_default), + (27, 3) => Self::_set_ixx(x, &config.i273, i, ii).unwrap_or(x_default), + (27, 4) => Self::_set_ixx(x, &config.i274, i, ii).unwrap_or(x_default), + (27, 5) => Self::_set_ixx(x, &config.i275, i, ii).unwrap_or(x_default), _ => todo!(), }; v_.push(x); @@ -181,6 +302,7 @@ impl OrtEngine { ); Ok(Self { + name: config.onnx_path.to_owned(), session, device, inputs_minoptmax, @@ -204,7 +326,7 @@ impl OrtEngine { fp16_enable: bool, engine_cache_enable: bool, ) -> Result<()> { - let span = tracing::span!(tracing::Level::INFO, "OrtEngine-new"); + let span = tracing::span!(tracing::Level::INFO, "OrtEngine-build_trt"); let _guard = 
span.enter(); // auto generate shapes @@ -284,11 +406,16 @@ impl OrtEngine { pub fn dry_run(&mut self) -> Result<()> { if self.num_dry_run > 0 { // pb + let name = std::path::Path::new(&self.name); let pb = build_progress_bar( self.num_dry_run as u64, " DryRun", - Some(&format!("{:?}", self.device)), - crate::PROGRESS_BAR_STYLE_CYAN, + Some( + name.file_name() + .and_then(|x| x.to_str()) + .unwrap_or_default(), + ), + crate::PROGRESS_BAR_STYLE_CYAN_2, )?; // dummy inputs @@ -311,8 +438,16 @@ impl OrtEngine { self.ts.clear(); // update + let name = std::path::Path::new(&self.name); + pb.set_message(format!( + "{} on {:?}", + name.file_name() + .and_then(|x| x.to_str()) + .unwrap_or_default(), + self.device, + )); pb.set_style(indicatif::ProgressStyle::with_template( - crate::PROGRESS_BAR_STYLE_GREEN, + crate::PROGRESS_BAR_STYLE_FINISH, )?); pb.finish(); } @@ -357,6 +492,7 @@ impl OrtEngine { // inference let t_run = std::time::Instant::now(); let outputs = self.session.run(&xs_[..])?; + let t_run = t_run.elapsed(); self.ts.add_or_push(1, t_run); @@ -370,21 +506,32 @@ impl OrtEngine { .zip(self.outputs_attrs.names.iter()) { let y = &outputs[name.as_str()]; + let y_ = match &dtype { - TensorElementType::Float32 => y.try_extract_tensor::()?.view().into_owned(), - TensorElementType::Float16 => y - .try_extract_tensor::()? - .view() - .mapv(f16::to_f32) - .into_owned(), - TensorElementType::Int64 => y - .try_extract_tensor::()? - .view() - .to_owned() - .mapv(|x| x as f32) - .into_owned(), + TensorElementType::Float32 => match y.try_extract_tensor::() { + Err(err) => { + tracing::error!("Error: {:?}. Output name: {:?}", err, name); + Array::zeros(0).into_dyn() + } + Ok(x) => x.view().into_owned(), + }, + TensorElementType::Float16 => match y.try_extract_tensor::() { + Err(err) => { + tracing::error!("Error: {:?}. 
Output name: {:?}", err, name); + Array::zeros(0).into_dyn() + } + Ok(x) => x.view().mapv(f16::to_f32).into_owned(), + }, + TensorElementType::Int64 => match y.try_extract_tensor::() { + Err(err) => { + tracing::error!("Error: {:?}. Output name: {:?}", err, name); + Array::zeros(0).into_dyn() + } + Ok(x) => x.view().to_owned().mapv(|x| x as f32).into_owned(), + }, _ => todo!(), }; + ys.push_kv(name.as_str(), X::from(y_))?; } let t_post = t_post.elapsed(); diff --git a/src/core/task.rs b/src/core/task.rs index b85a00d..8090625 100644 --- a/src/core/task.rs +++ b/src/core/task.rs @@ -1,27 +1,186 @@ -#[derive(Debug, Clone)] +#[derive(Debug, Clone, Ord, Eq, PartialOrd, PartialEq)] pub enum Task { - // vision + Untitled, + + /// Image classification task. + /// Input: image + /// Output: a label representing the class of the image ImageClassification, + + /// Multi-label image tagging task. + /// Input: image + /// Output: multiple labels representing different categories in the image + ImageTagging, + + /// Image captioning task, generating descriptions with different levels of detail. + /// Input: image + /// Output: a text description, `u8` represents the level of detail: + /// 0 for brief, 1 for detailed, 2 for more detailed + Caption(u8), + + /// Region proposal task, detecting all objects in the image. + /// Input: image + /// Output: bounding boxes (bboxes) + RegionProposal, + + /// Object detection task, detecting all objects in the image. + /// Input: image + /// Output: bounding boxes (bboxes), class labels, and optional scores for the detected objects ObjectDetection, + + /// Open set detection task, detecting and classifying objects in an image, with the ability to handle unseen or unknown objects. 
+ /// Input: image + /// Output: bounding boxes, class labels (including an "unknown" category for unfamiliar objects), and detection scores + /// Open set detection task, with String query + OpenSetDetection(String), + + /// Task for generating brief descriptions of dense regions in the image. + /// Input: image + /// Output: bounding boxes (bboxes), brief phrase labels, and optional scores for detected regions + DenseRegionCaption, + + /// Keypoint detection task, detecting keypoints in an image. + /// This can include human body parts (e.g., hands, feet, joints) or other objects. + /// Input: image + /// Output: coordinates of detected keypoints KeypointsDetection, - RegisonProposal, - PoseEstimation, + + /// Semantic segmentation task, segmenting the image into different semantic regions. + /// Input: image + /// Output: per-pixel class labels indicating object or background SemanticSegmentation, + + /// Instance segmentation task, detecting and segmenting individual object instances. + /// Input: image + /// Output: pixel masks for each object instance InstanceSegmentation, + + /// Depth estimation task, predicting the distance of each pixel from the camera. + /// Input: image + /// Output: a depth map where each pixel has a depth value DepthEstimation, + + /// Surface normal prediction task, predicting the surface normal vector for each pixel. + /// Input: image + /// Output: a normal map where each pixel has a surface normal vector SurfaceNormalPrediction, - Image2ImageGeneration, + + /// Image-to-image generation task, transforming one image into another. + /// Input: image + /// Output: a generated image + ImageToImageGeneration, + + /// Text-to-image generation task, generating an image based on a text description. + /// Input: text + /// Output: a generated image + TextToImageGeneration, + + /// Inpainting task, filling in missing or corrupted parts of an image. 
+ /// Input: image with missing or corrupted regions + /// Output: a complete image with the missing parts filled in Inpainting, + + /// Super-resolution task, enhancing the resolution of an image. + /// Input: low-resolution image + /// Output: high-resolution image SuperResolution, + + /// Image denoising task, removing noise from an image. + /// Input: noisy image + /// Output: denoised image Denoising, - // vl - Tagging, - Captioning, - DetailedCaptioning, - MoreDetailedCaptioning, - PhraseGrounding, - Vqa, + /// Phrase grounding task, finding the region in an image corresponding to a text description. + /// Input: image and text + /// Output: image region and the corresponding phrase + /// caption to phrase grounding + CaptionToPhraseGrounding(String), + + /// Referring expression segmentation task, segmenting objects in the image based on a text description. + /// Input: image and referring expression + /// Output: a segmentation mask for the object referred to by the text + ReferringExpressionSegmentation(String), + + /// Region-to-segmentation task, similar to combining object detection with segmentation (e.g., YOLO + SAM). + /// Input: image and region proposals + /// Output: segmentation masks for the regions + /// Region, bbox: top-left, bottom-right + RegionToSegmentation(usize, usize, usize, usize), + + /// Region-to-category classification task, classifying the object in a given region of the image. + /// Input: image and region + /// Output: class label for the region + /// Region, bbox: top-left, bottom-right + RegionToCategory(usize, usize, usize, usize), + + /// Region-to-description task, generating a detailed description for a given region in the image. + /// Input: image and region + /// Output: a detailed textual description for the region + /// Region, bbox: top-left, bottom-right + RegionToDescription(usize, usize, usize, usize), + + /// Visual question answering (VQA) task, answering questions related to an image. 
+ /// Input: image and question text + /// Output: the answer to the question + Vqa(String), + + /// Optical character recognition (OCR) task, recognizing text in an image. + /// Input: image + /// Output: recognized text Ocr, - Text2ImageGeneration, + + /// OCR task with region information, recognizing text and returning its location in the image. + /// Input: image + /// Output: recognized text and its bounding box in the image + OcrWithRegion, +} + +impl Task { + pub fn prompt_for_florence2(&self) -> anyhow::Result<String> { + let prompt = match self { + Self::Untitled => anyhow::bail!("No task specified."), + Self::Caption(0) => "What does the image describe?".to_string(), + Self::Caption(1) => "Describe in detail what is shown in the image.".to_string(), + Self::Caption(2) => "Describe with a paragraph what is shown in the image.".to_string(), + Self::Ocr => "What is the text in the image?".to_string(), + Self::OcrWithRegion => "What is the text in the image, with regions?".to_string(), + Self::ObjectDetection => { + "Locate the objects with category name in the image.".to_string() + } + Self::DenseRegionCaption => { + "Locate the objects in the image, with their descriptions.".to_string() + } + Self::RegionProposal => "Locate the region proposals in the image.".to_string(), + Self::OpenSetDetection(text) => { + format!("Locate {} in the image.", text) + } + Self::CaptionToPhraseGrounding(text) => { + format!("Locate the phrases in the caption: {}", text) + } + Self::ReferringExpressionSegmentation(text) => { + format!("Locate {} in the image with mask", text) + } + Self::RegionToSegmentation(x0, y0, x1, y1) => { + format!( + "What is the polygon mask of region <loc_{}><loc_{}><loc_{}><loc_{}>", + x0, y0, x1, y1 + ) + } + Self::RegionToCategory(x0, y0, x1, y1) => { + format!( + "What is the region <loc_{}><loc_{}><loc_{}><loc_{}>?", + x0, y0, x1, y1 + ) + } + Self::RegionToDescription(x0, y0, x1, y1) => { + format!( + "What does the region <loc_{}><loc_{}><loc_{}><loc_{}> describe?", + x0, y0, x1, y1 + ) + } + _ => anyhow::bail!("Unsupported task."), + }; + + 
Ok(prompt) + } } diff --git a/src/core/tokenizer_stream.rs b/src/core/tokenizer_stream.rs index 5fb8025..495d69a 100644 --- a/src/core/tokenizer_stream.rs +++ b/src/core/tokenizer_stream.rs @@ -1,4 +1,4 @@ -// https://github.com/huggingface/candle/blob/2a8679509eb55232b37378442c4366343f6dcb11/candle-examples/src/token_output_stream.rs#L5 +// TODO: refactor use anyhow::Result; /// This is a wrapper around a tokenizer to ensure that tokens can be returned to the user in a @@ -32,7 +32,6 @@ impl TokenizerStream { } } - // https://github.com/huggingface/text-generation-inference/blob/5ba53d44a18983a4de32d122f4cb46f4a17d9ef6/server/text_generation_server/models/model.py#L68 pub fn next_token(&mut self, token: u32) -> Result> { let prev_text = if self.tokens.is_empty() { String::new() @@ -42,7 +41,7 @@ impl TokenizerStream { }; self.tokens.push(token); let text = self.decode(&self.tokens[self.prev_index..])?; - if text.len() > prev_text.len() && text.chars().last().unwrap().is_alphanumeric() { + if text.len() > prev_text.len() { let text = text.split_at(prev_text.len()); self.prev_index = self.current_index; self.current_index = self.tokens.len(); diff --git a/src/core/vision.rs b/src/core/vision.rs index a4c8931..f78bc18 100644 --- a/src/core/vision.rs +++ b/src/core/vision.rs @@ -25,7 +25,7 @@ pub trait Vision: Sized { /// Executes the full pipeline. 
fn forward(&mut self, xs: &[Self::Input], profile: bool) -> anyhow::Result> { - let span = tracing::span!(tracing::Level::INFO, "DataLoader-new"); + let span = tracing::span!(tracing::Level::INFO, "Vision-forward"); let _guard = span.enter(); let t_pre = std::time::Instant::now(); diff --git a/src/core/x.rs b/src/core/x.rs index 433479d..3a245e2 100644 --- a/src/core/x.rs +++ b/src/core/x.rs @@ -98,6 +98,11 @@ impl X { Ok(self) } + pub fn concatenate(mut self, other: &Self, d: usize) -> Result { + self.0 = Ops::concatenate(&self.0, other, d)?; + Ok(self) + } + pub fn dims(&self) -> &[usize] { self.0.shape() } diff --git a/src/lib.rs b/src/lib.rs index fd63333..ce9d586 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -3,7 +3,7 @@ //! - **YOLO Models**: [YOLOv5](https://github.com/ultralytics/yolov5), [YOLOv6](https://github.com/meituan/YOLOv6), [YOLOv7](https://github.com/WongKinYiu/yolov7), [YOLOv8](https://github.com/ultralytics/ultralytics), [YOLOv9](https://github.com/WongKinYiu/yolov9), [YOLOv10](https://github.com/THU-MIG/yolov10) //! - **SAM Models**: [SAM](https://github.com/facebookresearch/segment-anything), [SAM2](https://github.com/facebookresearch/segment-anything-2), [MobileSAM](https://github.com/ChaoningZhang/MobileSAM), [EdgeSAM](https://github.com/chongzhou96/EdgeSAM), [SAM-HQ](https://github.com/SysCV/sam-hq), [FastSAM](https://github.com/CASIA-IVA-Lab/FastSAM) //! - **Vision Models**: [RTDETR](https://arxiv.org/abs/2304.08069), [RTMO](https://github.com/open-mmlab/mmpose/tree/main/projects/rtmo), [DB](https://arxiv.org/abs/1911.08947), [SVTR](https://arxiv.org/abs/2205.00159), [Depth-Anything-v1-v2](https://github.com/LiheYoung/Depth-Anything), [DINOv2](https://github.com/facebookresearch/dinov2), [MODNet](https://github.com/ZHKKKe/MODNet), [Sapiens](https://arxiv.org/abs/2408.12569) -//! 
- **Vision-Language Models**: [CLIP](https://github.com/openai/CLIP), [BLIP](https://arxiv.org/abs/2201.12086), [GroundingDINO](https://github.com/IDEA-Research/GroundingDINO), [YOLO-World](https://github.com/AILab-CVC/YOLO-World) +//! - **Vision-Language Models**: [CLIP](https://github.com/openai/CLIP), [BLIP](https://arxiv.org/abs/2201.12086), [GroundingDINO](https://github.com/IDEA-Research/GroundingDINO), [YOLO-World](https://github.com/AILab-CVC/YOLO-World), [Florence2](https://arxiv.org/abs/2311.06242) //! //! # Examples //! diff --git a/src/models/florence2.rs b/src/models/florence2.rs new file mode 100644 index 0000000..66eb497 --- /dev/null +++ b/src/models/florence2.rs @@ -0,0 +1,459 @@ +use anyhow::Result; +use image::DynamicImage; +use ndarray::{s, Axis}; +use rayon::prelude::*; +use std::collections::BTreeMap; +use tokenizers::Tokenizer; + +use crate::{ + build_progress_bar, Bbox, LogitsSampler, MinOptMax, Ops, Options, OrtEngine, Polygon, + Quantizer, Task, Xs, X, Y, +}; + +#[derive(Debug)] +pub struct Florence2 { + pub vision_encoder: OrtEngine, + pub text_embed: OrtEngine, + pub encoder: OrtEngine, + pub decoder: OrtEngine, + pub decoder_merged: OrtEngine, + height: MinOptMax, + width: MinOptMax, + batch: MinOptMax, + tokenizer: Tokenizer, + max_length: usize, + quantizer: Quantizer, +} + +impl Florence2 { + pub fn new( + options_vision_encoder: Options, + options_text_embed: Options, + options_encoder: Options, + options_decoder: Options, + options_decoder_merged: Options, + ) -> Result { + let mut vision_encoder = OrtEngine::new(&options_vision_encoder)?; + let mut text_embed = OrtEngine::new(&options_text_embed)?; + let mut encoder = OrtEngine::new(&options_encoder)?; + let mut decoder = OrtEngine::new(&options_decoder)?; + let mut decoder_merged = OrtEngine::new(&options_decoder_merged)?; + let (batch, height, width) = ( + vision_encoder.batch().to_owned(), + vision_encoder.height().to_owned(), + vision_encoder.width().to_owned(), + ); + let 
tokenizer = options_text_embed + .tokenizer + .ok_or(anyhow::anyhow!("No tokenizer file found"))?; + let tokenizer = match Tokenizer::from_file(tokenizer) { + Err(err) => anyhow::bail!("Failed to build tokenizer: {:?}", err), + Ok(x) => x, + }; + + let quantizer = Quantizer::default(); + let max_length = 1024; + + // dry run + vision_encoder.dry_run()?; + text_embed.dry_run()?; + encoder.dry_run()?; + decoder.dry_run()?; + decoder_merged.dry_run()?; + + Ok(Self { + vision_encoder, + text_embed, + encoder, + decoder, + decoder_merged, + height, + width, + batch, + tokenizer, + max_length, + quantizer, + }) + } + + pub fn encode_images(&mut self, xs: &[DynamicImage]) -> Result { + let xs_ = X::apply(&[ + Ops::Resize( + xs, + self.height.opt as u32, + self.width.opt as u32, + "Bilinear", + ), + Ops::Normalize(0., 255.), + Ops::Standardize(&[0.485, 0.456, 0.406], &[0.229, 0.224, 0.225], 3), + Ops::Nhwc2nchw, + ])?; + let ys = self.vision_encoder.run(Xs::from(xs_))?[0].to_owned(); + Ok(ys) + } + + pub fn run_with_tasks( + &mut self, + xs: &[DynamicImage], + tasks: &[Task], + ) -> Result>> { + let mut ys: BTreeMap> = BTreeMap::new(); + + // encode images + let image_embeddings = self.encode_images(xs)?; + + // note: the length of xs is not always equal to batch size + self.batch.update(xs.len() as isize); + + // build pb + let pb = build_progress_bar( + tasks.len() as u64, + " Working On", + None, + crate::PROGRESS_BAR_STYLE_CYAN_2, + )?; + + // tasks + for task in tasks.iter() { + pb.inc(1); + pb.set_message(format!("{:?}", task)); + + // construct prompt and encode + let input_ids = self + .encode_prompt(task)? + .insert_axis(0)? 
+ .repeat(0, self.batch())?; + let text_embeddings = self.text_embed.run(Xs::from(input_ids))?[0].clone(); + + // run + let texts = self.run_batch(&image_embeddings, &text_embeddings)?; + + // tasks iteration + let ys_task = (0..self.batch()) + .into_par_iter() + .map(|batch| { + // image size + let image_width = xs[batch].width() as usize; + let image_height = xs[batch].height() as usize; + + // texts cleanup + let text = texts[batch] + .as_str() + .replace("<s>", "") + .replace("</s>", "") + .replace("<pad>", ""); + + // postprocess + let mut y = Y::default(); + if let Task::Caption(_) | Task::Ocr = task { + y = y.with_texts(&[text]); + } else { + let elems = Self::loc_parse(&text)?; + match task { + Task::RegionToCategory(..) | Task::RegionToDescription(..) => { + let text = elems[0][0].clone(); + y = y.with_texts(&[text]); + } + Task::ObjectDetection + | Task::OpenSetDetection(_) + | Task::DenseRegionCaption + | Task::CaptionToPhraseGrounding(_) => { + let y_bboxes: Vec<Bbox> = elems + .par_iter() + .enumerate() + .flat_map(|(i, elem)| { + Self::process_bboxes( + &elem[1..], + &self.quantizer, + image_width, + image_height, + Some((&elem[0], i)), + ) + }) + .collect(); + y = y.with_bboxes(&y_bboxes); + } + Task::RegionProposal => { + let y_bboxes: Vec<Bbox> = Self::process_bboxes( + &elems[0], + &self.quantizer, + image_width, + image_height, + None, + ); + y = y.with_bboxes(&y_bboxes); + } + Task::ReferringExpressionSegmentation(_) + | Task::RegionToSegmentation(..) 
=> { + let points = Self::process_polygons( + &elems[0], + &self.quantizer, + image_width, + image_height, + ); + y = y.with_polygons(&[Polygon::default() + .with_points(&points) + .with_id(0)]); + } + Task::OcrWithRegion => { + let y_polygons: Vec = elems + .par_iter() + .enumerate() + .map(|(i, elem)| { + let points = Self::process_polygons( + &elem[1..], + &self.quantizer, + image_width, + image_height, + ); + Polygon::default() + .with_name(&elem[0]) + .with_points(&points) + .with_id(i as _) + }) + .collect(); + y = y.with_polygons(&y_polygons); + } + _ => anyhow::bail!("Unsupported Florence2 task."), + }; + } + Ok(y) + }) + .collect::>>()?; + + ys.insert(task.clone(), ys_task); + } + + // update pb + pb.set_prefix(" Completed"); + pb.set_message("Florence2 tasks"); + pb.set_style(indicatif::ProgressStyle::with_template( + crate::PROGRESS_BAR_STYLE_FINISH_2, + )?); + pb.finish(); + + Ok(ys) + } + + fn run_batch(&mut self, image_embeddings: &X, text_embeddings: &X) -> Result> { + // concate image_embeddings and prompt embeddings + let inputs_embeds = image_embeddings.clone().concatenate(text_embeddings, 1)?; + let attention_mask = X::ones(&[self.batch(), inputs_embeds.dims()[1]]); + + // encoder + let last_hidden_state = self.encoder.run(Xs::from(vec![ + attention_mask.clone(), + inputs_embeds.clone(), + ]))?[0] + .clone(); + + // decoder + let inputs_embeds = inputs_embeds.slice(s![.., -1.., ..]); + let inputs_embeds = X::from(inputs_embeds.to_owned().into_dyn()); + let mut decoder_outputs = self.decoder.run(Xs::from(vec![ + attention_mask.clone(), + last_hidden_state.clone(), + inputs_embeds, + ]))?; + + let encoder_k0 = decoder_outputs[3].clone(); + let encoder_v0 = decoder_outputs[4].clone(); + let encoder_k1 = decoder_outputs[7].clone(); + let encoder_v1 = decoder_outputs[8].clone(); + let encoder_k2 = decoder_outputs[11].clone(); + let encoder_v2 = decoder_outputs[12].clone(); + let encoder_k3 = decoder_outputs[15].clone(); + let encoder_v3 = 
decoder_outputs[16].clone(); + let encoder_k4 = decoder_outputs[19].clone(); + let encoder_v4 = decoder_outputs[20].clone(); + let encoder_k5 = decoder_outputs[23].clone(); + let encoder_v5 = decoder_outputs[24].clone(); + + let mut generated_tokens: Vec<Vec<u32>> = vec![vec![]; self.batch()]; + let mut finished = vec![false; self.batch()]; + + // save last batch tokens + let mut last_tokens: Vec<f32> = vec![0.; self.batch()]; + let mut logits_sampler = LogitsSampler::new(); + + // generate + for _ in 0..self.max_length { + let logits = &decoder_outputs["logits"]; + let decoder_k0 = &decoder_outputs[1]; + let decoder_v0 = &decoder_outputs[2]; + let decoder_k1 = &decoder_outputs[5]; + let decoder_v1 = &decoder_outputs[6]; + let decoder_k2 = &decoder_outputs[9]; + let decoder_v2 = &decoder_outputs[10]; + let decoder_k3 = &decoder_outputs[13]; + let decoder_v3 = &decoder_outputs[14]; + let decoder_k4 = &decoder_outputs[17]; + let decoder_v4 = &decoder_outputs[18]; + let decoder_k5 = &decoder_outputs[21]; + let decoder_v5 = &decoder_outputs[22]; + + // decode each token for each batch + for (i, logit) in logits.axis_iter(Axis(0)).enumerate() { + if !finished[i] { + let token_id = logits_sampler.decode( + &logit + .slice(s![-1, ..]) + .into_owned() + .into_raw_vec_and_offset() + .0, + )?; // + generated_tokens[i].push(token_id); + + // update last_tokens + last_tokens[i] = token_id as f32; + + if token_id == 2 { + finished[i] = true; + } + } + } + + // all finished?
+ if finished.iter().all(|&x| x) { + break; + } + + // next input text embedding + let next_tokens = X::from(last_tokens.clone()).insert_axis(1)?; + + // decode + let inputs_embeds = &self.text_embed.run(Xs::from(next_tokens))?[0].clone(); + let use_cache = X::ones(&[1]); + decoder_outputs = self.decoder_merged.run(Xs::from(vec![ + attention_mask.clone(), + last_hidden_state.clone(), + inputs_embeds.clone(), + decoder_k0.clone(), + decoder_v0.clone(), + encoder_k0.clone(), + encoder_v0.clone(), + decoder_k1.clone(), + decoder_v1.clone(), + encoder_k1.clone(), + encoder_v1.clone(), + decoder_k2.clone(), + decoder_v2.clone(), + encoder_k2.clone(), + encoder_v2.clone(), + decoder_k3.clone(), + decoder_v3.clone(), + encoder_k3.clone(), + encoder_v3.clone(), + decoder_k4.clone(), + decoder_v4.clone(), + encoder_k4.clone(), + encoder_v4.clone(), + decoder_k5.clone(), + decoder_v5.clone(), + encoder_k5.clone(), + encoder_v5.clone(), + use_cache, + ]))?; + } + + // batch decode + let texts = match self.tokenizer.decode_batch( + &generated_tokens + .iter() + .map(|tokens| tokens.as_slice()) + .collect::<Vec<_>>(), + false, + ) { + Err(err) => anyhow::bail!("{:?}", err), + Ok(xs) => xs, + }; + + Ok(texts) + } + + pub fn encode_prompt(&self, task: &Task) -> Result<X> { + let prompt = task.prompt_for_florence2()?; + let encodings = match self.tokenizer.encode(prompt, true) { + Err(err) => anyhow::bail!("{}", err), + Ok(x) => x, + }; + let ids: Vec<f32> = encodings.get_ids().iter().map(|x| *x as f32).collect(); + + Ok(X::from(ids)) + } + + fn process_polygons( + elems: &[String], + quantizer: &Quantizer, + image_width: usize, + image_height: usize, + ) -> Vec<Vec<f32>> { + elems + .par_chunks(2) + .map(|chunk| { + let coord: Vec<_> = chunk.iter().map(|s| s.parse::<usize>().unwrap()).collect(); + quantizer.dequantize(&coord, (image_width, image_height)) + }) + .collect() + } + + fn process_bboxes( + elems: &[String], + quantizer: &Quantizer, + image_width: usize, + image_height: usize, + class_name:
Option<(&str, usize)>, + ) -> Vec<Bbox> { + elems + .par_chunks(4) + .enumerate() + .map(|(i, chunk)| { + let bbox: Vec<_> = chunk.iter().map(|s| s.parse::<usize>().unwrap()).collect(); + let dequantized_bbox = quantizer.dequantize(&bbox, (image_width, image_height)); + + let mut bbox = Bbox::default().with_xyxy( + dequantized_bbox[0].max(0.0f32).min(image_width as f32), + dequantized_bbox[1].max(0.0f32).min(image_height as f32), + dequantized_bbox[2], + dequantized_bbox[3], + ); + if let Some((class_name, i)) = class_name { + bbox = bbox.with_name(class_name).with_id(i as _); + } else { + bbox = bbox.with_id(i as _); + } + + bbox + }) + .collect() + } + + fn loc_parse(hay: &str) -> Result<Vec<Vec<String>>> { + let pattern = r"(?i)(<loc_(?<coord>\d+)>)|(?<name>[^<]+)"; + let re = regex::Regex::new(pattern)?; + let mut ys: Vec<Vec<String>> = Vec::new(); + let mut y = Vec::new(); + + for cap in re.captures_iter(hay) { + if let Some(loc) = cap.name("coord") { + y.push(loc.as_str().to_string()); + } else if let Some(text) = cap.name("name") { + if !text.as_str().is_empty() { + if !y.is_empty() { + ys.push(y); + y = Vec::new(); + } + y.push(text.as_str().to_string()); + } + } + } + if !y.is_empty() { + ys.push(y); + } + Ok(ys) + } + + pub fn batch(&self) -> usize { + self.batch.opt as usize + } +} diff --git a/src/models/mod.rs b/src/models/mod.rs index 0d13182..28ecb53 100644 --- a/src/models/mod.rs +++ b/src/models/mod.rs @@ -5,6 +5,7 @@ mod clip; mod db; mod depth_anything; mod dinov2; +mod florence2; mod grounding_dino; mod modnet; mod rtmo; @@ -20,6 +21,7 @@ pub use clip::Clip; pub use db::DB; pub use depth_anything::DepthAnything; pub use dinov2::Dinov2; +pub use florence2::Florence2; pub use grounding_dino::GroundingDINO; pub use modnet::MODNet; pub use rtmo::RTMO; diff --git a/src/utils/mod.rs b/src/utils/mod.rs index dc4aef1..543d13c 100644 --- a/src/utils/mod.rs +++ b/src/utils/mod.rs @@ -5,9 +5,11 @@ use rand::{distributions::Alphanumeric, thread_rng, Rng}; pub mod colormap256; pub mod names; +mod quantizer; pub use
colormap256::*; pub use names::*; +pub use quantizer::Quantizer; pub(crate) const CHECK_MARK: &str = "✅"; pub(crate) const CROSS_MARK: &str = "❌"; @@ -28,6 +30,19 @@ pub(crate) const PROGRESS_BAR_STYLE_CYAN: &str = "{prefix:.cyan.bold} {msg} {human_pos}/{human_len} |{bar}| {elapsed_precise}"; pub(crate) const PROGRESS_BAR_STYLE_GREEN: &str = "{prefix:.green.bold} {msg} {human_pos}/{human_len} |{bar}| {elapsed_precise}"; +pub(crate) const PROGRESS_BAR_STYLE_CYAN_2: &str = + "{prefix:.cyan.bold} {human_pos}/{human_len} |{bar}| {msg}"; +pub(crate) const PROGRESS_BAR_STYLE_CYAN_3: &str = + "{prefix:.cyan.bold} |{bar}| {human_pos}/{human_len} {msg}"; +pub(crate) const PROGRESS_BAR_STYLE_GREEN_2: &str = + "{prefix:.green.bold} {human_pos}/{human_len} |{bar}| {elapsed_precise}"; +pub(crate) const PROGRESS_BAR_STYLE_FINISH: &str = + "{prefix:.green.bold} {msg} for {human_len} iterations in {elapsed}"; +pub(crate) const PROGRESS_BAR_STYLE_FINISH_2: &str = + "{prefix:.green.bold} {msg} x{human_len} in {elapsed}"; +pub(crate) const PROGRESS_BAR_STYLE_FINISH_3: &str = + "{prefix:.green.bold} {msg} ({binary_total_bytes}) in {elapsed}"; +pub(crate) const PROGRESS_BAR_STYLE_FINISH_4: &str = "{prefix:.green.bold} {msg} in {elapsed}"; pub fn human_bytes(size: f64) -> String { let units = ["B", "KB", "MB", "GB", "TB", "PB", "EB"]; diff --git a/src/utils/quantizer.rs b/src/utils/quantizer.rs new file mode 100644 index 0000000..1a3a6ac --- /dev/null +++ b/src/utils/quantizer.rs @@ -0,0 +1,82 @@ +#[derive(Debug)] +pub struct Quantizer { + bins: (usize, usize), +} + +impl Default for Quantizer { + fn default() -> Self { + Self { bins: (1000, 1000) } + } +} + +impl Quantizer { + pub fn new(bins: (usize, usize)) -> Self { + Quantizer { bins } + } + + fn quantize_value(&self, val: f32, bin_size: f64, max_bin: usize) -> usize { + ((val as f64 / bin_size).floor() as usize).clamp(0, max_bin - 1) + } + + fn dequantize_value(&self, val: usize, bin_size: f64) -> f32 { + ((val as f64 + 0.5) * 
bin_size) as f32 + } + + fn quantize_internal(&self, input: &[f32], size: (usize, usize)) -> Vec<usize> { + let (bins_w, bins_h) = self.bins; + let (size_w, size_h) = size; + + let size_per_bin_w = size_w as f64 / bins_w as f64; + let size_per_bin_h = size_h as f64 / bins_h as f64; + + match input.len() { + 4 => vec![ + self.quantize_value(input[0], size_per_bin_w, bins_w), + self.quantize_value(input[1], size_per_bin_h, bins_h), + self.quantize_value(input[2], size_per_bin_w, bins_w), + self.quantize_value(input[3], size_per_bin_h, bins_h), + ], + 2 => vec![ + self.quantize_value(input[0], size_per_bin_w, bins_w), + self.quantize_value(input[1], size_per_bin_h, bins_h), + ], + _ => panic!( + "Error: Unsupported input length: {} in Quantizer.", + input.len() + ), + } + } + + fn dequantize_internal(&self, input: &[usize], size: (usize, usize)) -> Vec<f32> { + let (bins_w, bins_h) = self.bins; + let (size_w, size_h) = size; + + let size_per_bin_w = size_w as f64 / bins_w as f64; + let size_per_bin_h = size_h as f64 / bins_h as f64; + + match input.len() { + 4 => vec![ + self.dequantize_value(input[0], size_per_bin_w), + self.dequantize_value(input[1], size_per_bin_h), + self.dequantize_value(input[2], size_per_bin_w), + self.dequantize_value(input[3], size_per_bin_h), + ], + 2 => vec![ + self.dequantize_value(input[0], size_per_bin_w), + self.dequantize_value(input[1], size_per_bin_h), + ], + _ => panic!( + "Error: Unsupported input length: {} in Quantizer.", + input.len() + ), + } + } + + pub fn quantize(&self, input: &[f32], size: (usize, usize)) -> Vec<usize> { + self.quantize_internal(input, size) + } + + pub fn dequantize(&self, input: &[usize], size: (usize, usize)) -> Vec<f32> { + self.dequantize_internal(input, size) + } +} diff --git a/src/ys/polygon.rs b/src/ys/polygon.rs index 48fa6b9..be20b32 100644 --- a/src/ys/polygon.rs +++ b/src/ys/polygon.rs @@ -49,6 +49,16 @@ impl Polygon { self } + pub fn with_points(mut self, points: &[Vec<f32>]) -> Self { + // exterior + let v = points
.iter() + .map(|p| coord! { x: p[0] as f64, y: p[1] as f64}) + .collect::<Vec<_>>(); + self.polygon = geo::Polygon::new(LineString::from(v), vec![]); + self + } + pub fn with_polygon(mut self, x: geo::Polygon) -> Self { + self.polygon = x; + self