From f0fd4936e894ddbbcdc70d65df3e2b8961e275bd Mon Sep 17 00:00:00 2001
From: Jamjamjon <51357717+jamjamjon@users.noreply.github.com>
Date: Sat, 21 Sep 2024 20:23:42 +0800
Subject: [PATCH] Add florence2 model
* Add florence2-base model for all tasks
* Update annotator.rs
---
Cargo.toml | 3 +-
README.md | 5 +-
examples/blip/main.rs | 2 +-
examples/dataloader/main.rs | 16 +-
examples/florence2/main.rs | 252 +++++++++++
src/core/annotator.rs | 43 +-
src/core/dataloader.rs | 15 +-
src/core/hub.rs | 84 ++--
src/core/min_opt_max.rs | 10 +
src/core/ops.rs | 10 +-
src/core/options.rs | 848 ++++++++++++++++++++++++++++++++++-
src/core/ort_engine.rs | 179 +++++++-
src/core/task.rs | 185 +++++++-
src/core/tokenizer_stream.rs | 5 +-
src/core/vision.rs | 2 +-
src/core/x.rs | 5 +
src/lib.rs | 2 +-
src/models/florence2.rs | 459 +++++++++++++++++++
src/models/mod.rs | 2 +
src/utils/mod.rs | 15 +
src/utils/quantizer.rs | 82 ++++
src/ys/polygon.rs | 10 +
22 files changed, 2124 insertions(+), 110 deletions(-)
create mode 100644 examples/florence2/main.rs
create mode 100644 src/models/florence2.rs
create mode 100644 src/utils/quantizer.rs
diff --git a/Cargo.toml b/Cargo.toml
index 799bd06..5beb404 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -1,6 +1,6 @@
[package]
name = "usls"
-version = "0.0.14"
+version = "0.0.15"
edition = "2021"
description = "A Rust library integrated with ONNXRuntime, providing a collection of ML models."
repository = "https://github.com/jamjamjon/usls"
@@ -22,7 +22,6 @@ dirs = { version = "5.0.1" }
ureq = { version = "2.9.1", default-features = true, features = [
"socks-proxy",
] }
-walkdir = { version = "2.5.0" } # TODO: remove
tokenizers = { version = "0.15.2" }
rayon = "1.10.0"
indicatif = "0.17.8"
diff --git a/README.md b/README.md
index a3ba257..9090e04 100644
--- a/README.md
+++ b/README.md
@@ -37,7 +37,7 @@
- **YOLO Models**: [YOLOv5](https://github.com/ultralytics/yolov5), [YOLOv6](https://github.com/meituan/YOLOv6), [YOLOv7](https://github.com/WongKinYiu/yolov7), [YOLOv8](https://github.com/ultralytics/ultralytics), [YOLOv9](https://github.com/WongKinYiu/yolov9), [YOLOv10](https://github.com/THU-MIG/yolov10)
- **SAM Models**: [SAM](https://github.com/facebookresearch/segment-anything), [SAM2](https://github.com/facebookresearch/segment-anything-2), [MobileSAM](https://github.com/ChaoningZhang/MobileSAM), [EdgeSAM](https://github.com/chongzhou96/EdgeSAM), [SAM-HQ](https://github.com/SysCV/sam-hq), [FastSAM](https://github.com/CASIA-IVA-Lab/FastSAM)
- **Vision Models**: [RTDETR](https://arxiv.org/abs/2304.08069), [RTMO](https://github.com/open-mmlab/mmpose/tree/main/projects/rtmo), [DB](https://arxiv.org/abs/1911.08947), [SVTR](https://arxiv.org/abs/2205.00159), [Depth-Anything-v1-v2](https://github.com/LiheYoung/Depth-Anything), [DINOv2](https://github.com/facebookresearch/dinov2), [MODNet](https://github.com/ZHKKKe/MODNet), [Sapiens](https://arxiv.org/abs/2408.12569)
-- **Vision-Language Models**: [CLIP](https://github.com/openai/CLIP), [BLIP](https://arxiv.org/abs/2201.12086), [GroundingDINO](https://github.com/IDEA-Research/GroundingDINO), [YOLO-World](https://github.com/AILab-CVC/YOLO-World)
+- **Vision-Language Models**: [CLIP](https://github.com/openai/CLIP), [BLIP](https://arxiv.org/abs/2201.12086), [GroundingDINO](https://github.com/IDEA-Research/GroundingDINO), [YOLO-World](https://github.com/AILab-CVC/YOLO-World), [Florence2](https://arxiv.org/abs/2311.06242)
Click to expand Supported Models
@@ -71,6 +71,9 @@
| [MODNet](https://github.com/ZHKKKe/MODNet) | Image Matting | [demo](examples/modnet) | ✅ | ✅ | ✅ | ✅ |
| [GroundingDINO](https://github.com/IDEA-Research/GroundingDINO) | Open-Set Detection With Language | [demo](examples/grounding-dino) | ✅ | ✅ | | |
| [Sapiens](https://github.com/facebookresearch/sapiens/tree/main) | Body Part Segmentation | [demo](examples/sapiens) | ✅ | ✅ | | |
+| [Florence2](https://arxiv.org/abs/2311.06242) | A Variety of Vision Tasks | [demo](examples/florence2) | ✅ | ✅ | | |
+
+
diff --git a/examples/blip/main.rs b/examples/blip/main.rs
index 656d5c2..b0cc929 100644
--- a/examples/blip/main.rs
+++ b/examples/blip/main.rs
@@ -10,7 +10,7 @@ fn main() -> Result<(), Box> {
// textual
let options_textual = Options::default()
.with_model("blip/textual-base.onnx")?
- // .with_tokenizer("blip/tokenizer.json")?
+ .with_tokenizer("blip/tokenizer.json")?
.with_i00((1, 1, 4).into()) // input_id: batch
.with_i01((1, 1, 4).into()) // input_id: seq_len
.with_i10((1, 1, 4).into()) // attention_mask: batch
diff --git a/examples/dataloader/main.rs b/examples/dataloader/main.rs
index a0b62db..40484f2 100644
--- a/examples/dataloader/main.rs
+++ b/examples/dataloader/main.rs
@@ -18,26 +18,24 @@ fn main() -> anyhow::Result<()> {
// build dataloader
let dl = DataLoader::new(
+ // "images/bus.jpg", // remote image
+ // "../images", // image folder
+ // "../demo.mp4", // local video
+ // "http://commondatastorage.googleapis.com/gtv-videos-bucket/sample/BigBuckBunny.mp4", // remote video
+ // "rtsp://admin:xyz@192.168.2.217:554/h265/ch1/", // rtsp h264 stream
"./assets/bus.jpg", // local image
- // "images/bus.jpg", // remote image
- // "../images", // image folder
- // "../demo.mp4", // local video
- // "http://commondatastorage.googleapis.com/gtv-videos-bucket/sample/BigBuckBunny.mp4", // remote video
- // "rtsp://admin:xyz@192.168.2.217:554/h265/ch1/", // rtsp h264 stream
)?
.with_batch(1)
- .with_progress_bar(true)
- .with_bound(100)
.build()?;
- // // build annotator
+ // build annotator
let annotator = Annotator::new()
.with_bboxes_thickness(4)
.with_saveout("YOLO-DataLoader");
// run
for (xs, _) in dl {
- // std::thread::sleep(std::time::Duration::from_millis(1000));
+ // std::thread::sleep(std::time::Duration::from_millis(100));
let ys = model.forward(&xs, false)?;
annotator.annotate(&xs, &ys);
}
diff --git a/examples/florence2/main.rs b/examples/florence2/main.rs
new file mode 100644
index 0000000..546ebb9
--- /dev/null
+++ b/examples/florence2/main.rs
@@ -0,0 +1,252 @@
+use usls::{models::Florence2, Annotator, DataLoader, Options, Task};
+
+fn main() -> Result<(), Box> {
+ // vision encoder
+ let options_vision_encoder = Options::default()
+ .with_model("florence2/base-vision-encoder.onnx")?
+ .with_i00((1, 2, 4).into())
+ .with_i02((512, 768, 800).into())
+ .with_i03((512, 768, 800).into())
+ .with_profile(false)
+ .with_cuda(0);
+
+ // text embed
+ let options_text_embed = Options::default()
+ .with_model("florence2/base-embed-tokens.onnx")?
+ .with_i00((1, 2, 4).into())
+ .with_i01((1, 2, 20).into()) // seq_length
+ .with_tokenizer("florence2/tokenizer.json")?
+ .with_profile(false);
+
+ // transformer encoder
+ let options_encoder = Options::default()
+ .with_model("florence2/base-encoder.onnx")?
+ .with_i00((1, 2, 4).into())
+ .with_i01((1, 2, 300).into()) // encoder_sequence_length
+ .with_i10((1, 2, 4).into())
+ .with_i11((1, 2, 300).into()) // encoder_sequence_length
+ .with_profile(false);
+
+ // transformer decoder
+ let options_decoder = Options::default()
+ .with_model("florence2/base-decoder.onnx")?
+ .with_i00((1, 2, 4).into())
+ .with_i01((1, 2, 300).into()) // encoder_sequence_length
+ .with_i10((1, 2, 4).into())
+ .with_i11((1, 2, 300).into()) // encoder_sequence_length
+ .with_i20((1, 2, 4).into())
+ .with_i21((1, 2, 300).into()) // encoder_sequence_length
+ .with_profile(false);
+
+ // transformer decoder merged
+ let options_decoder_merged = Options::default()
+ .with_model("florence2/base-decoder-merged.onnx")?
+ // encoder_attention_mask
+ .with_i00((1, 2, 4).into())
+ .with_i01((1, 2, 300).into()) // encoder_sequence_length
+ // encoder_hidden_states
+ .with_i10((1, 2, 4).into())
+ .with_i11((1, 2, 300).into()) // encoder_sequence_length
+ // inputs_embeds
+ .with_i20((1, 2, 4).into())
+ .with_i21((1, 2, 300).into()) // encoder_sequence_length
+ // past_key_values.0.decoder.key
+ .with_i30((1, 2, 4).into())
+ .with_i32_((1, 2, 1).into())
+ // past_key_values.0.decoder.value
+ .with_i40((1, 2, 4).into())
+ .with_i42((1, 2, 1).into())
+ // past_key_values.0.encoder.key
+ .with_i50((1, 2, 4).into())
+ .with_i52((1, 2, 1).into())
+ // past_key_values.0.encoder.value
+ .with_i60((1, 2, 4).into())
+ .with_i62((1, 2, 1).into())
+ // past_key_values.1.decoder.key
+ .with_i70((1, 2, 4).into())
+ .with_i72((1, 2, 1).into())
+ // past_key_values.1.decoder.value
+ .with_i80((1, 2, 4).into())
+ .with_i82((1, 2, 1).into())
+ // past_key_values.1.encoder.key
+ .with_i90((1, 2, 4).into())
+ .with_i92((1, 2, 1).into())
+ // past_key_values.1.encoder.value
+ .with_i100((1, 2, 4).into())
+ .with_i102((1, 2, 1).into())
+ // past_key_values.2.decoder.key
+ .with_i110((1, 2, 4).into())
+ .with_i112((1, 2, 1).into())
+ // past_key_values.2.decoder.value
+ .with_i120((1, 2, 4).into())
+ .with_i122((1, 2, 1).into())
+ // past_key_values.2.encoder.key
+ .with_i130((1, 2, 4).into())
+ .with_i132((1, 2, 1).into())
+ // past_key_values.2.encoder.value
+ .with_i140((1, 2, 4).into())
+ .with_i142((1, 2, 1).into())
+ // past_key_values.3.decoder.key
+ .with_i150((1, 2, 4).into())
+ .with_i152((1, 2, 1).into())
+ // past_key_values.3.decoder.value
+ .with_i160((1, 2, 4).into())
+ .with_i162((1, 2, 1).into())
+ // past_key_values.3.encoder.key
+ .with_i170((1, 2, 4).into())
+ .with_i172((1, 2, 1).into())
+ // past_key_values.3.encoder.value
+ .with_i180((1, 2, 4).into())
+ .with_i182((1, 2, 1).into())
+ // past_key_values.4.decoder.key
+ .with_i190((1, 2, 4).into())
+ .with_i192((1, 2, 1).into())
+ // past_key_values.4.decoder.value
+ .with_i200((1, 2, 4).into())
+ .with_i202((1, 2, 1).into())
+ // past_key_values.4.encoder.key
+ .with_i210((1, 2, 4).into())
+ .with_i212((1, 2, 1).into())
+ // past_key_values.4.encoder.value
+ .with_i220((1, 2, 4).into())
+ .with_i222((1, 2, 1).into())
+ // past_key_values.5.decoder.key
+ .with_i230((1, 2, 4).into())
+ .with_i232((1, 2, 1).into())
+ // past_key_values.5.decoder.value
+ .with_i240((1, 2, 4).into())
+ .with_i242((1, 2, 1).into())
+ // past_key_values.5.encoder.key
+ .with_i250((1, 2, 4).into())
+ .with_i252((1, 2, 1).into())
+ // past_key_values.5.encoder.value
+ .with_i260((1, 2, 4).into())
+ .with_i262((1, 2, 1).into())
+ // use_cache_branch
+ .with_i270((1, 2, 1).into())
+ .with_profile(false);
+
+ // build model
+ let mut model = Florence2::new(
+ options_vision_encoder,
+ options_text_embed,
+ options_encoder,
+ options_decoder,
+ options_decoder_merged,
+ )?;
+
+ // load images
+ let xs = [
+ // DataLoader::try_read("florence2/car.jpg")?, // for testing region-related tasks
+ DataLoader::try_read("florence2/car.jpg")?,
+ // DataLoader::try_read("images/db.png")?,
+ DataLoader::try_read("assets/bus.jpg")?,
+ ];
+
+ // region-related tasks
+ let quantizer = usls::Quantizer::default();
+ // let coords = [449., 270., 556., 372.]; // wheel
+ let coords = [31., 156., 581., 373.]; // car
+ let (width_car, height_car) = (xs[0].width(), xs[0].height());
+ let quantized_coords = quantizer.quantize(&coords, (width_car as _, height_car as _));
+
+ // run with tasks
+ let ys = model.run_with_tasks(
+ &xs,
+ &[
+ // w/o inputs: image-only tasks (no text prompt or region)
+ Task::Caption(0),
+ Task::Caption(1),
+ Task::Caption(2),
+ Task::Ocr,
+ Task::OcrWithRegion,
+ Task::RegionProposal,
+ Task::ObjectDetection,
+ Task::DenseRegionCaption,
+ // w/ inputs: tasks that take a text prompt or a region
+ Task::OpenSetDetection("a vehicle".into()),
+ Task::CaptionToPhraseGrounding(
+ "A vehicle with two wheels parked in front of a building.".into(),
+ ),
+ Task::ReferringExpressionSegmentation("a vehicle".into()),
+ Task::RegionToSegmentation(
+ quantized_coords[0],
+ quantized_coords[1],
+ quantized_coords[2],
+ quantized_coords[3],
+ ),
+ Task::RegionToCategory(
+ quantized_coords[0],
+ quantized_coords[1],
+ quantized_coords[2],
+ quantized_coords[3],
+ ),
+ Task::RegionToDescription(
+ quantized_coords[0],
+ quantized_coords[1],
+ quantized_coords[2],
+ quantized_coords[3],
+ ),
+ ],
+ )?;
+
+ // annotator
+ let annotator = Annotator::new()
+ .without_bboxes_conf(true)
+ .with_bboxes_thickness(3)
+ .with_saveout_subs(&["Florence2"]);
+ for (task, ys_) in ys.iter() {
+ match task {
+ Task::Caption(_)
+ | Task::Ocr
+ | Task::RegionToCategory(..)
+ | Task::RegionToDescription(..) => {
+ println!("Task: {:?}\n{:?}\n", task, ys_)
+ }
+ Task::DenseRegionCaption => {
+ let annotator = annotator.clone().with_saveout("Dense-Region-Caption");
+ annotator.annotate(&xs, ys_);
+ }
+ Task::RegionProposal => {
+ let annotator = annotator
+ .clone()
+ .without_bboxes_name(false)
+ .with_saveout("Region-Proposal");
+
+ annotator.annotate(&xs, ys_);
+ }
+ Task::ObjectDetection => {
+ let annotator = annotator.clone().with_saveout("Object-Detection");
+ annotator.annotate(&xs, ys_);
+ }
+ Task::OpenSetDetection(_) => {
+ let annotator = annotator.clone().with_saveout("Open-Set-Detection");
+ annotator.annotate(&xs, ys_);
+ }
+ Task::CaptionToPhraseGrounding(_) => {
+ let annotator = annotator
+ .clone()
+ .with_saveout("Caption-To-Phrase-Grounding");
+ annotator.annotate(&xs, ys_);
+ }
+ Task::ReferringExpressionSegmentation(_) => {
+ let annotator = annotator
+ .clone()
+ .with_saveout("Referring-Expression-Segmentation");
+ annotator.annotate(&xs, ys_);
+ }
+ Task::RegionToSegmentation(..) => {
+ let annotator = annotator.clone().with_saveout("Region-To-Segmentation");
+ annotator.annotate(&xs, ys_);
+ }
+ Task::OcrWithRegion => {
+ let annotator = annotator.clone().with_saveout("Ocr-With-Region");
+ annotator.annotate(&xs, ys_);
+ }
+
+ _ => (),
+ }
+ }
+
+ Ok(())
+}
diff --git a/src/core/annotator.rs b/src/core/annotator.rs
index 75a667c..af8c14d 100644
--- a/src/core/annotator.rs
+++ b/src/core/annotator.rs
@@ -2,19 +2,21 @@ use crate::{
colormap256, string_now, Bbox, Dir, Hub, Keypoint, Mask, Mbr, Polygon, Prob, CHECK_MARK,
CROSS_MARK, Y,
};
-use ab_glyph::{FontVec, PxScale};
+use ab_glyph::{FontArc, PxScale};
use anyhow::Result;
use image::{DynamicImage, GenericImage, Rgba, RgbaImage};
use imageproc::map::map_colors;
/// Annotator for struct `Y`
-// #[derive(Debug)]
+#[derive(Clone)]
pub struct Annotator {
- font: FontVec,
+ // TODO: Add lifetime
+ font: FontArc,
_scale: f32, // Cope with ab_glyph & imageproc=0.24.0
scale_dy: f32,
saveout_base: String,
saveout: Option,
+ saveout_subs: Vec,
decimal_places: usize,
// About mbrs
@@ -72,6 +74,7 @@ impl Default for Annotator {
scale_dy: 28.,
polygons_alpha: 179,
saveout: None,
+ saveout_subs: vec![],
saveout_base: String::from("runs"),
decimal_places: 4,
without_bboxes: false,
@@ -323,6 +326,11 @@ impl Annotator {
self
}
+ pub fn with_saveout_subs(mut self, xs: &[&str]) -> Self {
+ self.saveout_subs = xs.iter().map(|x| x.to_string()).collect::>();
+ self
+ }
+
pub fn with_font(mut self, path: &str) -> Result {
self.font = Self::load_font(Some(path))?;
Ok(self)
@@ -330,11 +338,23 @@ impl Annotator {
/// Create folders for saving annotated results. e.g., `./runs/xxx`
pub fn saveout(&self) -> Result {
- let subs = match &self.saveout {
- Some(x) => vec![self.saveout_base.as_str(), x.as_str()],
- None => vec![self.saveout_base.as_str()],
- };
+ let mut subs = vec![self.saveout_base.as_str()];
+ if let Some(saveout) = &self.saveout {
+ // add subs
+ if !self.saveout_subs.is_empty() {
+ let xs = self
+ .saveout_subs
+ .iter()
+ .map(|x| x.as_str())
+ .collect::>();
+ subs.extend(xs);
+ }
+
+ // add filename
+ subs.push(saveout);
+ }
+ // create the directory even when no filename is specified
Dir::Currnet.raw_path_with_subs(&subs)
}
@@ -345,7 +365,7 @@ impl Annotator {
/// Plot images and return plotted images(RGBA8)
pub fn plot(&self, imgs: &[DynamicImage], ys: &[Y]) -> Result> {
- let span = tracing::span!(tracing::Level::INFO, "YOLO-new");
+ let span = tracing::span!(tracing::Level::INFO, "Annotator-plot");
let _guard = span.enter();
let mut vs: Vec = Vec::new();
@@ -719,7 +739,7 @@ impl Annotator {
let top = if y > text_h as f32 {
(y.round() as u32 - text_h) as i32
} else {
- (text_h - y.round() as u32) as i32
+ 0
};
let mut left = x as i32;
if left + text_w as i32 > img.width() as i32 {
@@ -749,13 +769,13 @@ impl Annotator {
}
/// Load custom font
- fn load_font(path: Option<&str>) -> Result {
+ fn load_font(path: Option<&str>) -> Result {
let path_font = match path {
None => Hub::new()?.fetch("fonts/Arial.ttf")?.commit()?,
Some(p) => p.into(),
};
let buffer = std::fs::read(path_font)?;
- Ok(FontVec::try_from_vec(buffer.to_owned())?)
+ Ok(FontArc::try_from_vec(buffer.to_owned())?)
}
/// Pick color from pallette
@@ -765,6 +785,7 @@ impl Annotator {
/// Color pallette
fn color_palette() -> [(u8, u8, u8, u8); 20] {
+ // TODO: more colors
[
(0, 255, 127, 255), // spring green
(255, 105, 180, 255), // hot pink
diff --git a/src/core/dataloader.rs b/src/core/dataloader.rs
index c4dfc51..502ce92 100644
--- a/src/core/dataloader.rs
+++ b/src/core/dataloader.rs
@@ -35,7 +35,7 @@ impl Iterator for DataLoaderIterator {
None => {
progress_bar.set_prefix(" Iterated");
progress_bar.set_style(
- indicatif::ProgressStyle::with_template(crate::PROGRESS_BAR_STYLE_GREEN)
+ indicatif::ProgressStyle::with_template(crate::PROGRESS_BAR_STYLE_FINISH_2)
.unwrap(),
);
progress_bar.finish();
@@ -56,7 +56,7 @@ impl IntoIterator for DataLoader {
self.nf / self.batch_size as u64,
" Iterating",
Some(&format!("{:?}", self.media_type)),
- crate::PROGRESS_BAR_STYLE_CYAN,
+ crate::PROGRESS_BAR_STYLE_CYAN_2,
)
.ok()
} else {
@@ -365,7 +365,12 @@ impl DataLoader {
let saveout = Dir::Currnet
.raw_path_with_subs(subs)?
.join(format!("{}.mp4", string_now("-")));
- let pb = build_progress_bar(paths.len() as u64, " Converting", saveout.to_str(),"{prefix:.cyan.bold} {msg} |{bar}| ({percent_precise}%, {human_pos}/{human_len}, {per_sec})")?;
+ let pb = build_progress_bar(
+ paths.len() as u64,
+ " Converting",
+ Some(&format!("{:?}", MediaType::Video(Location::Local))),
+ crate::PROGRESS_BAR_STYLE_CYAN_2,
+ )?;
// loop
for path in paths {
@@ -397,10 +402,10 @@ impl DataLoader {
}
// update
- pb.set_prefix(" Downloaded");
pb.set_prefix(" Converted");
+ pb.set_message(saveout.to_str().unwrap_or_default().to_string());
pb.set_style(ProgressStyle::with_template(
- "{prefix:.green.bold} {msg} in {elapsed}",
+ crate::PROGRESS_BAR_STYLE_FINISH_4,
)?);
pb.finish();
diff --git a/src/core/hub.rs b/src/core/hub.rs
index f7b49e5..c3f3e90 100644
--- a/src/core/hub.rs
+++ b/src/core/hub.rs
@@ -152,65 +152,59 @@ impl Hub {
match p.exists() {
true => self.path = p,
false => {
- // check local cache 1st
- let p_cache = self.cache.with_file_name(s);
- if p_cache.exists() {
- self.path = p_cache;
- } else {
- // check remote list then
- match s.split_once('/') {
- Some((tag, file_name)) => {
- // Extract tag and file from input string
- self.tag = Some(tag.to_string());
- self.file_name = Some(file_name.to_string());
-
- // Check if releases are already loaded in memory
- if self.releases.is_none() {
- self.releases = Some(self.connect_remote()?);
+ // check remote
+ match s.split_once('/') {
+ Some((tag, file_name)) => {
+ // Extract tag and file from input string
+ self.tag = Some(tag.to_string());
+ self.file_name = Some(file_name.to_string());
+
+ // Check if releases are already loaded in memory
+ if self.releases.is_none() {
+ self.releases = Some(self.connect_remote()?);
+ }
+
+ if let Some(releases) = &self.releases {
+ // Validate the tag
+ let tags: Vec<&str> =
+ releases.iter().map(|x| x.tag_name.as_str()).collect();
+ if !tags.contains(&tag) {
+ anyhow::bail!(
+ "Hub tag '{}' not found in releases. Available tags: {:?}",
+ tag,
+ tags
+ );
}
- if let Some(releases) = &self.releases {
- // Validate the tag
- let tags: Vec<&str> =
- releases.iter().map(|x| x.tag_name.as_str()).collect();
- if !tags.contains(&tag) {
+ // Validate the file
+ if let Some(release) = releases.iter().find(|r| r.tag_name == tag) {
+ let files: Vec<&str> =
+ release.assets.iter().map(|x| x.name.as_str()).collect();
+ if !files.contains(&file_name) {
anyhow::bail!(
- "Hub tag '{}' not found in releases. Available tags: {:?}",
- tag,
- tags
- );
- }
-
- // Validate the file
- if let Some(release) = releases.iter().find(|r| r.tag_name == tag) {
- let files: Vec<&str> =
- release.assets.iter().map(|x| x.name.as_str()).collect();
- if !files.contains(&file_name) {
- anyhow::bail!(
"Hub file '{}' not found in tag '{}'. Available files: {:?}",
file_name,
tag,
files
);
- } else {
- for f_ in release.assets.iter() {
- if f_.name.as_str() == file_name {
- self.url = Some(f_.browser_download_url.clone());
- self.file_size = Some(f_.size);
-
- break;
- }
+ } else {
+ for f_ in release.assets.iter() {
+ if f_.name.as_str() == file_name {
+ self.url = Some(f_.browser_download_url.clone());
+ self.file_size = Some(f_.size);
+
+ break;
}
}
}
- self.path = self.to.path_with_subs(&[tag])?.join(file_name);
}
+ self.path = self.to.path_with_subs(&[tag])?.join(file_name);
}
- _ => anyhow::bail!(
+ }
+ _ => anyhow::bail!(
"Download failed due to invalid format. Expected: /, got: {}",
s
),
- }
}
}
}
@@ -286,7 +280,7 @@ impl Hub {
}
pub fn connect_remote(&mut self) -> Result> {
- let span = tracing::span!(tracing::Level::INFO, "OrtEngine-run");
+ let span = tracing::span!(tracing::Level::INFO, "Hub-connect_remote");
let _guard = span.enter();
let should_download = if !self.cache.exists() {
@@ -416,7 +410,7 @@ impl Hub {
// update
pb.set_prefix(" Downloaded");
pb.set_style(ProgressStyle::with_template(
- "{prefix:.green.bold} {msg} ({binary_total_bytes}) in {elapsed}",
+ crate::PROGRESS_BAR_STYLE_FINISH_3,
)?);
pb.finish();
diff --git a/src/core/min_opt_max.rs b/src/core/min_opt_max.rs
index 71c4970..803d4d2 100644
--- a/src/core/min_opt_max.rs
+++ b/src/core/min_opt_max.rs
@@ -50,4 +50,14 @@ impl MinOptMax {
max: opt,
}
}
+
+ pub fn update(&mut self, opt: isize) {
+ self.opt = opt;
+ if self.min > opt {
+ self.min = opt;
+ }
+ if self.max < opt {
+ self.max = opt;
+ }
+ }
}
diff --git a/src/core/ops.rs b/src/core/ops.rs
index 10d3a74..7b60fb3 100644
--- a/src/core/ops.rs
+++ b/src/core/ops.rs
@@ -7,7 +7,7 @@ use fast_image_resize::{
FilterType, ResizeAlg, ResizeOptions, Resizer,
};
use image::{DynamicImage, GenericImageView};
-use ndarray::{s, Array, Array3, Axis, IntoDimension, IxDyn};
+use ndarray::{concatenate, s, Array, Array3, Axis, IntoDimension, IxDyn};
use rayon::prelude::*;
pub enum Ops<'a> {
@@ -114,6 +114,14 @@ impl Ops<'_> {
Self::permute(x, &[0, 2, 3, 1])
}
+ pub fn concatenate(
+ x: &Array,
+ y: &Array,
+ d: usize,
+ ) -> Result> {
+ Ok(concatenate(Axis(d), &[x.view(), y.view()])?)
+ }
+
pub fn insert_axis(x: Array, d: usize) -> Result> {
if x.shape().len() < d {
anyhow::bail!(
diff --git a/src/core/options.rs b/src/core/options.rs
index 2c6cc41..3bd8691 100644
--- a/src/core/options.rs
+++ b/src/core/options.rs
@@ -4,13 +4,14 @@ use anyhow::Result;
use crate::{
models::{SamKind, SapiensTask, YOLOPreds, YOLOTask, YOLOVersion},
- Device, Hub, MinOptMax,
+ Device, Hub, MinOptMax, Task,
};
/// Options for building models
#[derive(Debug, Clone)]
pub struct Options {
pub onnx_path: String,
+ pub task: Task,
pub device: Device,
pub profile: bool,
pub num_dry_run: usize,
@@ -62,6 +63,126 @@ pub struct Options {
pub i73: Option,
pub i74: Option,
pub i75: Option,
+ pub i80: Option,
+ pub i81: Option,
+ pub i82: Option,
+ pub i83: Option,
+ pub i84: Option,
+ pub i85: Option,
+ pub i90: Option,
+ pub i91: Option,
+ pub i92: Option,
+ pub i93: Option,
+ pub i94: Option,
+ pub i95: Option,
+ pub i100: Option,
+ pub i101: Option,
+ pub i102: Option,
+ pub i103: Option,
+ pub i104: Option,
+ pub i105: Option,
+ pub i110: Option,
+ pub i111: Option,
+ pub i112: Option,
+ pub i113: Option,
+ pub i114: Option,
+ pub i115: Option,
+ pub i120: Option,
+ pub i121: Option,
+ pub i122: Option,
+ pub i123: Option,
+ pub i124: Option,
+ pub i125: Option,
+ pub i130: Option,
+ pub i131: Option,
+ pub i132: Option,
+ pub i133: Option,
+ pub i134: Option,
+ pub i135: Option,
+ pub i140: Option,
+ pub i141: Option,
+ pub i142: Option,
+ pub i143: Option,
+ pub i144: Option,
+ pub i145: Option,
+ pub i150: Option,
+ pub i151: Option,
+ pub i152: Option,
+ pub i153: Option,
+ pub i154: Option,
+ pub i155: Option,
+ pub i160: Option,
+ pub i161: Option,
+ pub i162: Option,
+ pub i163: Option,
+ pub i164: Option,
+ pub i165: Option,
+ pub i170: Option,
+ pub i171: Option,
+ pub i172: Option,
+ pub i173: Option,
+ pub i174: Option,
+ pub i175: Option,
+ pub i180: Option,
+ pub i181: Option,
+ pub i182: Option,
+ pub i183: Option,
+ pub i184: Option,
+ pub i185: Option,
+ pub i190: Option,
+ pub i191: Option,
+ pub i192: Option,
+ pub i193: Option,
+ pub i194: Option,
+ pub i195: Option,
+ pub i200: Option,
+ pub i201: Option,
+ pub i202: Option,
+ pub i203: Option,
+ pub i204: Option,
+ pub i205: Option,
+ pub i210: Option,
+ pub i211: Option,
+ pub i212: Option,
+ pub i213: Option,
+ pub i214: Option,
+ pub i215: Option,
+ pub i220: Option,
+ pub i221: Option,
+ pub i222: Option,
+ pub i223: Option,
+ pub i224: Option,
+ pub i225: Option,
+ pub i230: Option,
+ pub i231: Option,
+ pub i232: Option,
+ pub i233: Option,
+ pub i234: Option,
+ pub i235: Option,
+ pub i240: Option,
+ pub i241: Option,
+ pub i242: Option,
+ pub i243: Option,
+ pub i244: Option,
+ pub i245: Option,
+ pub i250: Option,
+ pub i251: Option,
+ pub i252: Option,
+ pub i253: Option,
+ pub i254: Option,
+ pub i255: Option,
+ pub i260: Option,
+ pub i261: Option,
+ pub i262: Option,
+ pub i263: Option,
+ pub i264: Option,
+ pub i265: Option,
+ pub i270: Option,
+ pub i271: Option,
+ pub i272: Option,
+ pub i273: Option,
+ pub i274: Option,
+ pub i275: Option,
// trt related
pub trt_engine_cache_enable: bool,
pub trt_int8_enable: bool,
@@ -149,6 +270,126 @@ impl Default for Options {
i73: None,
i74: None,
i75: None,
+ i80: None,
+ i81: None,
+ i82: None,
+ i83: None,
+ i84: None,
+ i85: None,
+ i90: None,
+ i91: None,
+ i92: None,
+ i93: None,
+ i94: None,
+ i95: None,
+ i100: None,
+ i101: None,
+ i102: None,
+ i103: None,
+ i104: None,
+ i105: None,
+ i110: None,
+ i111: None,
+ i112: None,
+ i113: None,
+ i114: None,
+ i115: None,
+ i120: None,
+ i121: None,
+ i122: None,
+ i123: None,
+ i124: None,
+ i125: None,
+ i130: None,
+ i131: None,
+ i132: None,
+ i133: None,
+ i134: None,
+ i135: None,
+ i140: None,
+ i141: None,
+ i142: None,
+ i143: None,
+ i144: None,
+ i145: None,
+ i150: None,
+ i151: None,
+ i152: None,
+ i153: None,
+ i154: None,
+ i155: None,
+ i160: None,
+ i161: None,
+ i162: None,
+ i163: None,
+ i164: None,
+ i165: None,
+ i170: None,
+ i171: None,
+ i172: None,
+ i173: None,
+ i174: None,
+ i175: None,
+ i180: None,
+ i181: None,
+ i182: None,
+ i183: None,
+ i184: None,
+ i185: None,
+ i190: None,
+ i191: None,
+ i192: None,
+ i193: None,
+ i194: None,
+ i195: None,
+ i200: None,
+ i201: None,
+ i202: None,
+ i203: None,
+ i204: None,
+ i205: None,
+ i210: None,
+ i211: None,
+ i212: None,
+ i213: None,
+ i214: None,
+ i215: None,
+ i220: None,
+ i221: None,
+ i222: None,
+ i223: None,
+ i224: None,
+ i225: None,
+ i230: None,
+ i231: None,
+ i232: None,
+ i233: None,
+ i234: None,
+ i235: None,
+ i240: None,
+ i241: None,
+ i242: None,
+ i243: None,
+ i244: None,
+ i245: None,
+ i250: None,
+ i251: None,
+ i252: None,
+ i253: None,
+ i254: None,
+ i255: None,
+ i260: None,
+ i261: None,
+ i262: None,
+ i263: None,
+ i264: None,
+ i265: None,
+ i270: None,
+ i271: None,
+ i272: None,
+ i273: None,
+ i274: None,
+ i275: None,
trt_engine_cache_enable: true,
trt_int8_enable: false,
trt_fp16_enable: false,
@@ -176,6 +417,7 @@ impl Default for Options {
sam_kind: None,
use_low_res_mask: None,
sapiens_task: None,
+ task: Task::Untitled,
}
}
}
@@ -185,6 +427,11 @@ impl Options {
Default::default()
}
+ pub fn with_task(mut self, task: Task) -> Self {
+ self.task = task;
+ self
+ }
+
pub fn with_model(mut self, onnx_path: &str) -> Result {
self.onnx_path = Hub::new()?.fetch(onnx_path)?.commit()?;
Ok(self)
@@ -579,4 +826,603 @@ impl Options {
self.i75 = Some(x);
self
}
+
+ pub fn with_i80(mut self, x: MinOptMax) -> Self {
+ self.i80 = Some(x);
+ self
+ }
+
+ pub fn with_i81(mut self, x: MinOptMax) -> Self {
+ self.i81 = Some(x);
+ self
+ }
+
+ pub fn with_i82(mut self, x: MinOptMax) -> Self {
+ self.i82 = Some(x);
+ self
+ }
+
+ pub fn with_i83(mut self, x: MinOptMax) -> Self {
+ self.i83 = Some(x);
+ self
+ }
+
+ pub fn with_i84(mut self, x: MinOptMax) -> Self {
+ self.i84 = Some(x);
+ self
+ }
+
+ pub fn with_i85(mut self, x: MinOptMax) -> Self {
+ self.i85 = Some(x);
+ self
+ }
+
+ pub fn with_i90(mut self, x: MinOptMax) -> Self {
+ self.i90 = Some(x);
+ self
+ }
+
+ pub fn with_i91(mut self, x: MinOptMax) -> Self {
+ self.i91 = Some(x);
+ self
+ }
+
+ pub fn with_i92(mut self, x: MinOptMax) -> Self {
+ self.i92 = Some(x);
+ self
+ }
+
+ pub fn with_i93(mut self, x: MinOptMax) -> Self {
+ self.i93 = Some(x);
+ self
+ }
+
+ pub fn with_i94(mut self, x: MinOptMax) -> Self {
+ self.i94 = Some(x);
+ self
+ }
+
+ pub fn with_i95(mut self, x: MinOptMax) -> Self {
+ self.i95 = Some(x);
+ self
+ }
+
+ pub fn with_i100(mut self, x: MinOptMax) -> Self {
+ self.i100 = Some(x);
+ self
+ }
+
+ pub fn with_i101(mut self, x: MinOptMax) -> Self {
+ self.i101 = Some(x);
+ self
+ }
+
+ pub fn with_i102(mut self, x: MinOptMax) -> Self {
+ self.i102 = Some(x);
+ self
+ }
+
+ pub fn with_i103(mut self, x: MinOptMax) -> Self {
+ self.i103 = Some(x);
+ self
+ }
+
+ pub fn with_i104(mut self, x: MinOptMax) -> Self {
+ self.i104 = Some(x);
+ self
+ }
+
+ pub fn with_i105(mut self, x: MinOptMax) -> Self {
+ self.i105 = Some(x);
+ self
+ }
+
+ pub fn with_i110(mut self, x: MinOptMax) -> Self {
+ self.i110 = Some(x);
+ self
+ }
+
+ pub fn with_i111(mut self, x: MinOptMax) -> Self {
+ self.i111 = Some(x);
+ self
+ }
+
+ pub fn with_i112(mut self, x: MinOptMax) -> Self {
+ self.i112 = Some(x);
+ self
+ }
+
+ pub fn with_i113(mut self, x: MinOptMax) -> Self {
+ self.i113 = Some(x);
+ self
+ }
+
+ pub fn with_i114(mut self, x: MinOptMax) -> Self {
+ self.i114 = Some(x);
+ self
+ }
+
+ pub fn with_i115(mut self, x: MinOptMax) -> Self {
+ self.i115 = Some(x);
+ self
+ }
+
+ pub fn with_i120(mut self, x: MinOptMax) -> Self {
+ self.i120 = Some(x);
+ self
+ }
+
+ pub fn with_i121(mut self, x: MinOptMax) -> Self {
+ self.i121 = Some(x);
+ self
+ }
+
+ pub fn with_i122(mut self, x: MinOptMax) -> Self {
+ self.i122 = Some(x);
+ self
+ }
+
+ pub fn with_i123(mut self, x: MinOptMax) -> Self {
+ self.i123 = Some(x);
+ self
+ }
+
+ pub fn with_i124(mut self, x: MinOptMax) -> Self {
+ self.i124 = Some(x);
+ self
+ }
+
+ pub fn with_i125(mut self, x: MinOptMax) -> Self {
+ self.i125 = Some(x);
+ self
+ }
+
+ pub fn with_i130(mut self, x: MinOptMax) -> Self {
+ self.i130 = Some(x);
+ self
+ }
+
+ pub fn with_i131(mut self, x: MinOptMax) -> Self {
+ self.i131 = Some(x);
+ self
+ }
+
+ pub fn with_i132(mut self, x: MinOptMax) -> Self {
+ self.i132 = Some(x);
+ self
+ }
+
+ pub fn with_i133(mut self, x: MinOptMax) -> Self {
+ self.i133 = Some(x);
+ self
+ }
+
+ pub fn with_i134(mut self, x: MinOptMax) -> Self {
+ self.i134 = Some(x);
+ self
+ }
+
+ pub fn with_i135(mut self, x: MinOptMax) -> Self {
+ self.i135 = Some(x);
+ self
+ }
+
+ pub fn with_i140(mut self, x: MinOptMax) -> Self {
+ self.i140 = Some(x);
+ self
+ }
+
+ pub fn with_i141(mut self, x: MinOptMax) -> Self {
+ self.i141 = Some(x);
+ self
+ }
+
+ pub fn with_i142(mut self, x: MinOptMax) -> Self {
+ self.i142 = Some(x);
+ self
+ }
+
+ pub fn with_i143(mut self, x: MinOptMax) -> Self {
+ self.i143 = Some(x);
+ self
+ }
+
+ pub fn with_i144(mut self, x: MinOptMax) -> Self {
+ self.i144 = Some(x);
+ self
+ }
+
+ pub fn with_i145(mut self, x: MinOptMax) -> Self {
+ self.i145 = Some(x);
+ self
+ }
+
+ pub fn with_i150(mut self, x: MinOptMax) -> Self {
+ self.i150 = Some(x);
+ self
+ }
+
+ pub fn with_i151(mut self, x: MinOptMax) -> Self {
+ self.i151 = Some(x);
+ self
+ }
+
+ pub fn with_i152(mut self, x: MinOptMax) -> Self {
+ self.i152 = Some(x);
+ self
+ }
+
+ pub fn with_i153(mut self, x: MinOptMax) -> Self {
+ self.i153 = Some(x);
+ self
+ }
+
+ pub fn with_i154(mut self, x: MinOptMax) -> Self {
+ self.i154 = Some(x);
+ self
+ }
+
+ pub fn with_i155(mut self, x: MinOptMax) -> Self {
+ self.i155 = Some(x);
+ self
+ }
+
+ pub fn with_i160(mut self, x: MinOptMax) -> Self {
+ self.i160 = Some(x);
+ self
+ }
+
+ pub fn with_i161(mut self, x: MinOptMax) -> Self {
+ self.i161 = Some(x);
+ self
+ }
+
+ pub fn with_i162(mut self, x: MinOptMax) -> Self {
+ self.i162 = Some(x);
+ self
+ }
+
+ pub fn with_i163(mut self, x: MinOptMax) -> Self {
+ self.i163 = Some(x);
+ self
+ }
+
+ pub fn with_i164(mut self, x: MinOptMax) -> Self {
+ self.i164 = Some(x);
+ self
+ }
+
+ pub fn with_i165(mut self, x: MinOptMax) -> Self {
+ self.i165 = Some(x);
+ self
+ }
+
+ pub fn with_i170(mut self, x: MinOptMax) -> Self {
+ self.i170 = Some(x);
+ self
+ }
+
+ pub fn with_i171(mut self, x: MinOptMax) -> Self {
+ self.i171 = Some(x);
+ self
+ }
+
+ pub fn with_i172(mut self, x: MinOptMax) -> Self {
+ self.i172 = Some(x);
+ self
+ }
+
+ pub fn with_i173(mut self, x: MinOptMax) -> Self {
+ self.i173 = Some(x);
+ self
+ }
+
+ pub fn with_i174(mut self, x: MinOptMax) -> Self {
+ self.i174 = Some(x);
+ self
+ }
+
+ pub fn with_i175(mut self, x: MinOptMax) -> Self {
+ self.i175 = Some(x);
+ self
+ }
+
+ pub fn with_i180(mut self, x: MinOptMax) -> Self {
+ self.i180 = Some(x);
+ self
+ }
+
+ pub fn with_i181(mut self, x: MinOptMax) -> Self {
+ self.i181 = Some(x);
+ self
+ }
+
+ pub fn with_i182(mut self, x: MinOptMax) -> Self {
+ self.i182 = Some(x);
+ self
+ }
+
+ pub fn with_i183(mut self, x: MinOptMax) -> Self {
+ self.i183 = Some(x);
+ self
+ }
+
+ pub fn with_i184(mut self, x: MinOptMax) -> Self {
+ self.i184 = Some(x);
+ self
+ }
+
+ pub fn with_i185(mut self, x: MinOptMax) -> Self {
+ self.i185 = Some(x);
+ self
+ }
+
+ pub fn with_i190(mut self, x: MinOptMax) -> Self {
+ self.i190 = Some(x);
+ self
+ }
+
+ pub fn with_i191(mut self, x: MinOptMax) -> Self {
+ self.i191 = Some(x);
+ self
+ }
+
+ pub fn with_i192(mut self, x: MinOptMax) -> Self {
+ self.i192 = Some(x);
+ self
+ }
+
+ pub fn with_i193(mut self, x: MinOptMax) -> Self {
+ self.i193 = Some(x);
+ self
+ }
+
+ pub fn with_i194(mut self, x: MinOptMax) -> Self {
+ self.i194 = Some(x);
+ self
+ }
+
+ pub fn with_i195(mut self, x: MinOptMax) -> Self {
+ self.i195 = Some(x);
+ self
+ }
+
+ pub fn with_i200(mut self, x: MinOptMax) -> Self {
+ self.i200 = Some(x);
+ self
+ }
+
+ pub fn with_i201(mut self, x: MinOptMax) -> Self {
+ self.i201 = Some(x);
+ self
+ }
+
+ pub fn with_i202(mut self, x: MinOptMax) -> Self {
+ self.i202 = Some(x);
+ self
+ }
+
+ pub fn with_i203(mut self, x: MinOptMax) -> Self {
+ self.i203 = Some(x);
+ self
+ }
+
+ pub fn with_i204(mut self, x: MinOptMax) -> Self {
+ self.i204 = Some(x);
+ self
+ }
+
+ pub fn with_i205(mut self, x: MinOptMax) -> Self {
+ self.i205 = Some(x);
+ self
+ }
+
+ pub fn with_i210(mut self, x: MinOptMax) -> Self {
+ self.i210 = Some(x);
+ self
+ }
+
+ pub fn with_i211(mut self, x: MinOptMax) -> Self {
+ self.i211 = Some(x);
+ self
+ }
+
+ pub fn with_i212(mut self, x: MinOptMax) -> Self {
+ self.i212 = Some(x);
+ self
+ }
+
+ pub fn with_i213(mut self, x: MinOptMax) -> Self {
+ self.i213 = Some(x);
+ self
+ }
+
+ pub fn with_i214(mut self, x: MinOptMax) -> Self {
+ self.i214 = Some(x);
+ self
+ }
+
+ pub fn with_i215(mut self, x: MinOptMax) -> Self {
+ self.i215 = Some(x);
+ self
+ }
+
+ pub fn with_i220(mut self, x: MinOptMax) -> Self {
+ self.i220 = Some(x);
+ self
+ }
+
+ pub fn with_i221(mut self, x: MinOptMax) -> Self {
+ self.i221 = Some(x);
+ self
+ }
+
+ pub fn with_i222(mut self, x: MinOptMax) -> Self {
+ self.i222 = Some(x);
+ self
+ }
+
+ pub fn with_i223(mut self, x: MinOptMax) -> Self {
+ self.i223 = Some(x);
+ self
+ }
+
+ pub fn with_i224(mut self, x: MinOptMax) -> Self {
+ self.i224 = Some(x);
+ self
+ }
+
+ pub fn with_i225(mut self, x: MinOptMax) -> Self {
+ self.i225 = Some(x);
+ self
+ }
+
+ pub fn with_i230(mut self, x: MinOptMax) -> Self {
+ self.i230 = Some(x);
+ self
+ }
+
+ pub fn with_i231(mut self, x: MinOptMax) -> Self {
+ self.i231 = Some(x);
+ self
+ }
+
+ pub fn with_i232(mut self, x: MinOptMax) -> Self {
+ self.i232 = Some(x);
+ self
+ }
+
+ pub fn with_i233(mut self, x: MinOptMax) -> Self {
+ self.i233 = Some(x);
+ self
+ }
+
+ pub fn with_i234(mut self, x: MinOptMax) -> Self {
+ self.i234 = Some(x);
+ self
+ }
+
+ pub fn with_i235(mut self, x: MinOptMax) -> Self {
+ self.i235 = Some(x);
+ self
+ }
+
+ pub fn with_i240(mut self, x: MinOptMax) -> Self {
+ self.i240 = Some(x);
+ self
+ }
+
+ pub fn with_i241(mut self, x: MinOptMax) -> Self {
+ self.i241 = Some(x);
+ self
+ }
+
+ pub fn with_i242(mut self, x: MinOptMax) -> Self {
+ self.i242 = Some(x);
+ self
+ }
+
+ pub fn with_i243(mut self, x: MinOptMax) -> Self {
+ self.i243 = Some(x);
+ self
+ }
+
+ pub fn with_i244(mut self, x: MinOptMax) -> Self {
+ self.i244 = Some(x);
+ self
+ }
+
+ pub fn with_i245(mut self, x: MinOptMax) -> Self {
+ self.i245 = Some(x);
+ self
+ }
+
+ pub fn with_i250(mut self, x: MinOptMax) -> Self {
+ self.i250 = Some(x);
+ self
+ }
+
+ pub fn with_i251(mut self, x: MinOptMax) -> Self {
+ self.i251 = Some(x);
+ self
+ }
+
+ pub fn with_i252(mut self, x: MinOptMax) -> Self {
+ self.i252 = Some(x);
+ self
+ }
+
+ pub fn with_i253(mut self, x: MinOptMax) -> Self {
+ self.i253 = Some(x);
+ self
+ }
+
+ pub fn with_i254(mut self, x: MinOptMax) -> Self {
+ self.i254 = Some(x);
+ self
+ }
+
+ pub fn with_i255(mut self, x: MinOptMax) -> Self {
+ self.i255 = Some(x);
+ self
+ }
+ pub fn with_i260(mut self, x: MinOptMax) -> Self {
+ self.i260 = Some(x);
+ self
+ }
+
+ pub fn with_i261(mut self, x: MinOptMax) -> Self {
+ self.i261 = Some(x);
+ self
+ }
+
+ pub fn with_i262(mut self, x: MinOptMax) -> Self {
+ self.i262 = Some(x);
+ self
+ }
+
+ pub fn with_i263(mut self, x: MinOptMax) -> Self {
+ self.i263 = Some(x);
+ self
+ }
+
+ pub fn with_i264(mut self, x: MinOptMax) -> Self {
+ self.i264 = Some(x);
+ self
+ }
+
+ pub fn with_i265(mut self, x: MinOptMax) -> Self {
+ self.i265 = Some(x);
+ self
+ }
+
+ pub fn with_i270(mut self, x: MinOptMax) -> Self {
+ self.i270 = Some(x);
+ self
+ }
+
+ pub fn with_i271(mut self, x: MinOptMax) -> Self {
+ self.i271 = Some(x);
+ self
+ }
+
+ pub fn with_i272(mut self, x: MinOptMax) -> Self {
+ self.i272 = Some(x);
+ self
+ }
+
+ pub fn with_i273(mut self, x: MinOptMax) -> Self {
+ self.i273 = Some(x);
+ self
+ }
+
+ pub fn with_i274(mut self, x: MinOptMax) -> Self {
+ self.i274 = Some(x);
+ self
+ }
+
+ pub fn with_i275(mut self, x: MinOptMax) -> Self {
+ self.i275 = Some(x);
+ self
+ }
}
diff --git a/src/core/ort_engine.rs b/src/core/ort_engine.rs
index dd2994c..255a0ae 100644
--- a/src/core/ort_engine.rs
+++ b/src/core/ort_engine.rs
@@ -23,6 +23,7 @@ pub struct OrtTensorAttr {
/// ONNXRuntime Backend
#[derive(Debug)]
pub struct OrtEngine {
+ name: String,
session: Session,
device: Device,
inputs_minoptmax: Vec<Vec<MinOptMax>>,
@@ -129,6 +130,126 @@ impl OrtEngine {
(7, 3) => Self::_set_ixx(x, &config.i73, i, ii).unwrap_or(x_default),
(7, 4) => Self::_set_ixx(x, &config.i74, i, ii).unwrap_or(x_default),
(7, 5) => Self::_set_ixx(x, &config.i75, i, ii).unwrap_or(x_default),
+ (8, 0) => Self::_set_ixx(x, &config.i80, i, ii).unwrap_or(x_default),
+ (8, 1) => Self::_set_ixx(x, &config.i81, i, ii).unwrap_or(x_default),
+ (8, 2) => Self::_set_ixx(x, &config.i82, i, ii).unwrap_or(x_default),
+ (8, 3) => Self::_set_ixx(x, &config.i83, i, ii).unwrap_or(x_default),
+ (8, 4) => Self::_set_ixx(x, &config.i84, i, ii).unwrap_or(x_default),
+ (8, 5) => Self::_set_ixx(x, &config.i85, i, ii).unwrap_or(x_default),
+ (9, 0) => Self::_set_ixx(x, &config.i90, i, ii).unwrap_or(x_default),
+ (9, 1) => Self::_set_ixx(x, &config.i91, i, ii).unwrap_or(x_default),
+ (9, 2) => Self::_set_ixx(x, &config.i92, i, ii).unwrap_or(x_default),
+ (9, 3) => Self::_set_ixx(x, &config.i93, i, ii).unwrap_or(x_default),
+ (9, 4) => Self::_set_ixx(x, &config.i94, i, ii).unwrap_or(x_default),
+ (9, 5) => Self::_set_ixx(x, &config.i95, i, ii).unwrap_or(x_default),
+ (10, 0) => Self::_set_ixx(x, &config.i100, i, ii).unwrap_or(x_default),
+ (10, 1) => Self::_set_ixx(x, &config.i101, i, ii).unwrap_or(x_default),
+ (10, 2) => Self::_set_ixx(x, &config.i102, i, ii).unwrap_or(x_default),
+ (10, 3) => Self::_set_ixx(x, &config.i103, i, ii).unwrap_or(x_default),
+ (10, 4) => Self::_set_ixx(x, &config.i104, i, ii).unwrap_or(x_default),
+ (10, 5) => Self::_set_ixx(x, &config.i105, i, ii).unwrap_or(x_default),
+ (11, 0) => Self::_set_ixx(x, &config.i110, i, ii).unwrap_or(x_default),
+ (11, 1) => Self::_set_ixx(x, &config.i111, i, ii).unwrap_or(x_default),
+ (11, 2) => Self::_set_ixx(x, &config.i112, i, ii).unwrap_or(x_default),
+ (11, 3) => Self::_set_ixx(x, &config.i113, i, ii).unwrap_or(x_default),
+ (11, 4) => Self::_set_ixx(x, &config.i114, i, ii).unwrap_or(x_default),
+ (11, 5) => Self::_set_ixx(x, &config.i115, i, ii).unwrap_or(x_default),
+ (12, 0) => Self::_set_ixx(x, &config.i120, i, ii).unwrap_or(x_default),
+ (12, 1) => Self::_set_ixx(x, &config.i121, i, ii).unwrap_or(x_default),
+ (12, 2) => Self::_set_ixx(x, &config.i122, i, ii).unwrap_or(x_default),
+ (12, 3) => Self::_set_ixx(x, &config.i123, i, ii).unwrap_or(x_default),
+ (12, 4) => Self::_set_ixx(x, &config.i124, i, ii).unwrap_or(x_default),
+ (12, 5) => Self::_set_ixx(x, &config.i125, i, ii).unwrap_or(x_default),
+ (13, 0) => Self::_set_ixx(x, &config.i130, i, ii).unwrap_or(x_default),
+ (13, 1) => Self::_set_ixx(x, &config.i131, i, ii).unwrap_or(x_default),
+ (13, 2) => Self::_set_ixx(x, &config.i132, i, ii).unwrap_or(x_default),
+ (13, 3) => Self::_set_ixx(x, &config.i133, i, ii).unwrap_or(x_default),
+ (13, 4) => Self::_set_ixx(x, &config.i134, i, ii).unwrap_or(x_default),
+ (13, 5) => Self::_set_ixx(x, &config.i135, i, ii).unwrap_or(x_default),
+ (14, 0) => Self::_set_ixx(x, &config.i140, i, ii).unwrap_or(x_default),
+ (14, 1) => Self::_set_ixx(x, &config.i141, i, ii).unwrap_or(x_default),
+ (14, 2) => Self::_set_ixx(x, &config.i142, i, ii).unwrap_or(x_default),
+ (14, 3) => Self::_set_ixx(x, &config.i143, i, ii).unwrap_or(x_default),
+ (14, 4) => Self::_set_ixx(x, &config.i144, i, ii).unwrap_or(x_default),
+ (14, 5) => Self::_set_ixx(x, &config.i145, i, ii).unwrap_or(x_default),
+ (15, 0) => Self::_set_ixx(x, &config.i150, i, ii).unwrap_or(x_default),
+ (15, 1) => Self::_set_ixx(x, &config.i151, i, ii).unwrap_or(x_default),
+ (15, 2) => Self::_set_ixx(x, &config.i152, i, ii).unwrap_or(x_default),
+ (15, 3) => Self::_set_ixx(x, &config.i153, i, ii).unwrap_or(x_default),
+ (15, 4) => Self::_set_ixx(x, &config.i154, i, ii).unwrap_or(x_default),
+ (15, 5) => Self::_set_ixx(x, &config.i155, i, ii).unwrap_or(x_default),
+ (16, 0) => Self::_set_ixx(x, &config.i160, i, ii).unwrap_or(x_default),
+ (16, 1) => Self::_set_ixx(x, &config.i161, i, ii).unwrap_or(x_default),
+ (16, 2) => Self::_set_ixx(x, &config.i162, i, ii).unwrap_or(x_default),
+ (16, 3) => Self::_set_ixx(x, &config.i163, i, ii).unwrap_or(x_default),
+ (16, 4) => Self::_set_ixx(x, &config.i164, i, ii).unwrap_or(x_default),
+ (16, 5) => Self::_set_ixx(x, &config.i165, i, ii).unwrap_or(x_default),
+ (17, 0) => Self::_set_ixx(x, &config.i170, i, ii).unwrap_or(x_default),
+ (17, 1) => Self::_set_ixx(x, &config.i171, i, ii).unwrap_or(x_default),
+ (17, 2) => Self::_set_ixx(x, &config.i172, i, ii).unwrap_or(x_default),
+ (17, 3) => Self::_set_ixx(x, &config.i173, i, ii).unwrap_or(x_default),
+ (17, 4) => Self::_set_ixx(x, &config.i174, i, ii).unwrap_or(x_default),
+ (17, 5) => Self::_set_ixx(x, &config.i175, i, ii).unwrap_or(x_default),
+ (18, 0) => Self::_set_ixx(x, &config.i180, i, ii).unwrap_or(x_default),
+ (18, 1) => Self::_set_ixx(x, &config.i181, i, ii).unwrap_or(x_default),
+ (18, 2) => Self::_set_ixx(x, &config.i182, i, ii).unwrap_or(x_default),
+ (18, 3) => Self::_set_ixx(x, &config.i183, i, ii).unwrap_or(x_default),
+ (18, 4) => Self::_set_ixx(x, &config.i184, i, ii).unwrap_or(x_default),
+ (18, 5) => Self::_set_ixx(x, &config.i185, i, ii).unwrap_or(x_default),
+ (19, 0) => Self::_set_ixx(x, &config.i190, i, ii).unwrap_or(x_default),
+ (19, 1) => Self::_set_ixx(x, &config.i191, i, ii).unwrap_or(x_default),
+ (19, 2) => Self::_set_ixx(x, &config.i192, i, ii).unwrap_or(x_default),
+ (19, 3) => Self::_set_ixx(x, &config.i193, i, ii).unwrap_or(x_default),
+ (19, 4) => Self::_set_ixx(x, &config.i194, i, ii).unwrap_or(x_default),
+ (19, 5) => Self::_set_ixx(x, &config.i195, i, ii).unwrap_or(x_default),
+ (20, 0) => Self::_set_ixx(x, &config.i200, i, ii).unwrap_or(x_default),
+ (20, 1) => Self::_set_ixx(x, &config.i201, i, ii).unwrap_or(x_default),
+ (20, 2) => Self::_set_ixx(x, &config.i202, i, ii).unwrap_or(x_default),
+ (20, 3) => Self::_set_ixx(x, &config.i203, i, ii).unwrap_or(x_default),
+ (20, 4) => Self::_set_ixx(x, &config.i204, i, ii).unwrap_or(x_default),
+ (20, 5) => Self::_set_ixx(x, &config.i205, i, ii).unwrap_or(x_default),
+ (21, 0) => Self::_set_ixx(x, &config.i210, i, ii).unwrap_or(x_default),
+ (21, 1) => Self::_set_ixx(x, &config.i211, i, ii).unwrap_or(x_default),
+ (21, 2) => Self::_set_ixx(x, &config.i212, i, ii).unwrap_or(x_default),
+ (21, 3) => Self::_set_ixx(x, &config.i213, i, ii).unwrap_or(x_default),
+ (21, 4) => Self::_set_ixx(x, &config.i214, i, ii).unwrap_or(x_default),
+ (21, 5) => Self::_set_ixx(x, &config.i215, i, ii).unwrap_or(x_default),
+ (22, 0) => Self::_set_ixx(x, &config.i220, i, ii).unwrap_or(x_default),
+ (22, 1) => Self::_set_ixx(x, &config.i221, i, ii).unwrap_or(x_default),
+ (22, 2) => Self::_set_ixx(x, &config.i222, i, ii).unwrap_or(x_default),
+ (22, 3) => Self::_set_ixx(x, &config.i223, i, ii).unwrap_or(x_default),
+ (22, 4) => Self::_set_ixx(x, &config.i224, i, ii).unwrap_or(x_default),
+ (22, 5) => Self::_set_ixx(x, &config.i225, i, ii).unwrap_or(x_default),
+ (23, 0) => Self::_set_ixx(x, &config.i230, i, ii).unwrap_or(x_default),
+ (23, 1) => Self::_set_ixx(x, &config.i231, i, ii).unwrap_or(x_default),
+ (23, 2) => Self::_set_ixx(x, &config.i232, i, ii).unwrap_or(x_default),
+ (23, 3) => Self::_set_ixx(x, &config.i233, i, ii).unwrap_or(x_default),
+ (23, 4) => Self::_set_ixx(x, &config.i234, i, ii).unwrap_or(x_default),
+ (23, 5) => Self::_set_ixx(x, &config.i235, i, ii).unwrap_or(x_default),
+ (24, 0) => Self::_set_ixx(x, &config.i240, i, ii).unwrap_or(x_default),
+ (24, 1) => Self::_set_ixx(x, &config.i241, i, ii).unwrap_or(x_default),
+ (24, 2) => Self::_set_ixx(x, &config.i242, i, ii).unwrap_or(x_default),
+ (24, 3) => Self::_set_ixx(x, &config.i243, i, ii).unwrap_or(x_default),
+ (24, 4) => Self::_set_ixx(x, &config.i244, i, ii).unwrap_or(x_default),
+ (24, 5) => Self::_set_ixx(x, &config.i245, i, ii).unwrap_or(x_default),
+ (25, 0) => Self::_set_ixx(x, &config.i250, i, ii).unwrap_or(x_default),
+ (25, 1) => Self::_set_ixx(x, &config.i251, i, ii).unwrap_or(x_default),
+ (25, 2) => Self::_set_ixx(x, &config.i252, i, ii).unwrap_or(x_default),
+ (25, 3) => Self::_set_ixx(x, &config.i253, i, ii).unwrap_or(x_default),
+ (25, 4) => Self::_set_ixx(x, &config.i254, i, ii).unwrap_or(x_default),
+ (25, 5) => Self::_set_ixx(x, &config.i255, i, ii).unwrap_or(x_default),
+ (26, 0) => Self::_set_ixx(x, &config.i260, i, ii).unwrap_or(x_default),
+ (26, 1) => Self::_set_ixx(x, &config.i261, i, ii).unwrap_or(x_default),
+ (26, 2) => Self::_set_ixx(x, &config.i262, i, ii).unwrap_or(x_default),
+ (26, 3) => Self::_set_ixx(x, &config.i263, i, ii).unwrap_or(x_default),
+ (26, 4) => Self::_set_ixx(x, &config.i264, i, ii).unwrap_or(x_default),
+ (26, 5) => Self::_set_ixx(x, &config.i265, i, ii).unwrap_or(x_default),
+ (27, 0) => Self::_set_ixx(x, &config.i270, i, ii).unwrap_or(x_default),
+ (27, 1) => Self::_set_ixx(x, &config.i271, i, ii).unwrap_or(x_default),
+ (27, 2) => Self::_set_ixx(x, &config.i272, i, ii).unwrap_or(x_default),
+ (27, 3) => Self::_set_ixx(x, &config.i273, i, ii).unwrap_or(x_default),
+ (27, 4) => Self::_set_ixx(x, &config.i274, i, ii).unwrap_or(x_default),
+ (27, 5) => Self::_set_ixx(x, &config.i275, i, ii).unwrap_or(x_default),
_ => todo!(),
};
v_.push(x);
@@ -181,6 +302,7 @@ impl OrtEngine {
);
Ok(Self {
+ name: config.onnx_path.to_owned(),
session,
device,
inputs_minoptmax,
@@ -204,7 +326,7 @@ impl OrtEngine {
fp16_enable: bool,
engine_cache_enable: bool,
) -> Result<()> {
- let span = tracing::span!(tracing::Level::INFO, "OrtEngine-new");
+ let span = tracing::span!(tracing::Level::INFO, "OrtEngine-build_trt");
let _guard = span.enter();
// auto generate shapes
@@ -284,11 +406,16 @@ impl OrtEngine {
pub fn dry_run(&mut self) -> Result<()> {
if self.num_dry_run > 0 {
// pb
+ let name = std::path::Path::new(&self.name);
let pb = build_progress_bar(
self.num_dry_run as u64,
" DryRun",
- Some(&format!("{:?}", self.device)),
- crate::PROGRESS_BAR_STYLE_CYAN,
+ Some(
+ name.file_name()
+ .and_then(|x| x.to_str())
+ .unwrap_or_default(),
+ ),
+ crate::PROGRESS_BAR_STYLE_CYAN_2,
)?;
// dummy inputs
@@ -311,8 +438,16 @@ impl OrtEngine {
self.ts.clear();
// update
+ let name = std::path::Path::new(&self.name);
+ pb.set_message(format!(
+ "{} on {:?}",
+ name.file_name()
+ .and_then(|x| x.to_str())
+ .unwrap_or_default(),
+ self.device,
+ ));
pb.set_style(indicatif::ProgressStyle::with_template(
- crate::PROGRESS_BAR_STYLE_GREEN,
+ crate::PROGRESS_BAR_STYLE_FINISH,
)?);
pb.finish();
}
@@ -357,6 +492,7 @@ impl OrtEngine {
// inference
let t_run = std::time::Instant::now();
let outputs = self.session.run(&xs_[..])?;
+
let t_run = t_run.elapsed();
self.ts.add_or_push(1, t_run);
@@ -370,21 +506,32 @@ impl OrtEngine {
.zip(self.outputs_attrs.names.iter())
{
let y = &outputs[name.as_str()];
+
let y_ = match &dtype {
- TensorElementType::Float32 => y.try_extract_tensor::<f32>()?.view().into_owned(),
- TensorElementType::Float16 => y
- .try_extract_tensor::<f16>()?
- .view()
- .mapv(f16::to_f32)
- .into_owned(),
- TensorElementType::Int64 => y
- .try_extract_tensor::<i64>()?
- .view()
- .to_owned()
- .mapv(|x| x as f32)
- .into_owned(),
+ TensorElementType::Float32 => match y.try_extract_tensor::<f32>() {
+ Err(err) => {
+ tracing::error!("Error: {:?}. Output name: {:?}", err, name);
+ Array::zeros(0).into_dyn()
+ }
+ Ok(x) => x.view().into_owned(),
+ },
+ TensorElementType::Float16 => match y.try_extract_tensor::<f16>() {
+ Err(err) => {
+ tracing::error!("Error: {:?}. Output name: {:?}", err, name);
+ Array::zeros(0).into_dyn()
+ }
+ Ok(x) => x.view().mapv(f16::to_f32).into_owned(),
+ },
+ TensorElementType::Int64 => match y.try_extract_tensor::<i64>() {
+ Err(err) => {
+ tracing::error!("Error: {:?}. Output name: {:?}", err, name);
+ Array::zeros(0).into_dyn()
+ }
+ Ok(x) => x.view().to_owned().mapv(|x| x as f32).into_owned(),
+ },
_ => todo!(),
};
+
ys.push_kv(name.as_str(), X::from(y_))?;
}
let t_post = t_post.elapsed();
diff --git a/src/core/task.rs b/src/core/task.rs
index b85a00d..8090625 100644
--- a/src/core/task.rs
+++ b/src/core/task.rs
@@ -1,27 +1,186 @@
-#[derive(Debug, Clone)]
+#[derive(Debug, Clone, Ord, Eq, PartialOrd, PartialEq)]
pub enum Task {
- // vision
+ Untitled,
+
+ /// Image classification task.
+ /// Input: image
+ /// Output: a label representing the class of the image
ImageClassification,
+
+ /// Multi-label image tagging task.
+ /// Input: image
+ /// Output: multiple labels representing different categories in the image
+ ImageTagging,
+
+ /// Image captioning task, generating descriptions with different levels of detail.
+ /// Input: image
+ /// Output: a text description, `u8` represents the level of detail:
+ /// 0 for brief, 1 for detailed, 2 for more detailed
+ Caption(u8),
+
+ /// Region proposal task, detecting all objects in the image.
+ /// Input: image
+ /// Output: bounding boxes (bboxes)
+ RegionProposal,
+
+ /// Object detection task, detecting all objects in the image.
+ /// Input: image
+ /// Output: bounding boxes (bboxes), class labels, and optional scores for the detected objects
ObjectDetection,
+
+ /// Open set detection task, detecting and classifying objects in an image, with the ability to handle unseen or unknown objects.
+ /// Input: image
+ /// Output: bounding boxes, class labels (including an "unknown" category for unfamiliar objects), and detection scores
+ /// The `String` payload is the text query describing what to locate in the image.
+ OpenSetDetection(String),
+
+ /// Task for generating brief descriptions of dense regions in the image.
+ /// Input: image
+ /// Output: bounding boxes (bboxes), brief phrase labels, and optional scores for detected regions
+ DenseRegionCaption,
+
+ /// Keypoint detection task, detecting keypoints in an image.
+ /// This can include human body parts (e.g., hands, feet, joints) or other objects.
+ /// Input: image
+ /// Output: coordinates of detected keypoints
KeypointsDetection,
- RegisonProposal,
- PoseEstimation,
+
+ /// Semantic segmentation task, segmenting the image into different semantic regions.
+ /// Input: image
+ /// Output: per-pixel class labels indicating object or background
SemanticSegmentation,
+
+ /// Instance segmentation task, detecting and segmenting individual object instances.
+ /// Input: image
+ /// Output: pixel masks for each object instance
InstanceSegmentation,
+
+ /// Depth estimation task, predicting the distance of each pixel from the camera.
+ /// Input: image
+ /// Output: a depth map where each pixel has a depth value
DepthEstimation,
+
+ /// Surface normal prediction task, predicting the surface normal vector for each pixel.
+ /// Input: image
+ /// Output: a normal map where each pixel has a surface normal vector
SurfaceNormalPrediction,
- Image2ImageGeneration,
+
+ /// Image-to-image generation task, transforming one image into another.
+ /// Input: image
+ /// Output: a generated image
+ ImageToImageGeneration,
+
+ /// Text-to-image generation task, generating an image based on a text description.
+ /// Input: text
+ /// Output: a generated image
+ TextToImageGeneration,
+
+ /// Inpainting task, filling in missing or corrupted parts of an image.
+ /// Input: image with missing or corrupted regions
+ /// Output: a complete image with the missing parts filled in
Inpainting,
+
+ /// Super-resolution task, enhancing the resolution of an image.
+ /// Input: low-resolution image
+ /// Output: high-resolution image
SuperResolution,
+
+ /// Image denoising task, removing noise from an image.
+ /// Input: noisy image
+ /// Output: denoised image
Denoising,
- // vl
- Tagging,
- Captioning,
- DetailedCaptioning,
- MoreDetailedCaptioning,
- PhraseGrounding,
- Vqa,
+ /// Phrase grounding task, finding the region in an image corresponding to a text description.
+ /// Input: image and text
+ /// Output: image region and the corresponding phrase
+ /// The `String` payload is the caption whose phrases are located in the image.
+ CaptionToPhraseGrounding(String),
+
+ /// Referring expression segmentation task, segmenting objects in the image based on a text description.
+ /// Input: image and referring expression
+ /// Output: a segmentation mask for the object referred to by the text
+ ReferringExpressionSegmentation(String),
+
+ /// Region-to-segmentation task, similar to combining object detection with segmentation (e.g., YOLO + SAM).
+ /// Input: image and region proposals
+ /// Output: segmentation masks for the regions
+ /// The four `usize` values are the region's bbox: top-left (x0, y0), then bottom-right (x1, y1).
+ RegionToSegmentation(usize, usize, usize, usize),
+
+ /// Region-to-category classification task, classifying the object in a given region of the image.
+ /// Input: image and region
+ /// Output: class label for the region
+ /// The four `usize` values are the region's bbox: top-left (x0, y0), then bottom-right (x1, y1).
+ RegionToCategory(usize, usize, usize, usize),
+
+ /// Region-to-description task, generating a detailed description for a given region in the image.
+ /// Input: image and region
+ /// Output: a detailed textual description for the region
+ /// The four `usize` values are the region's bbox: top-left (x0, y0), then bottom-right (x1, y1).
+ RegionToDescription(usize, usize, usize, usize),
+
+ /// Visual question answering (VQA) task, answering questions related to an image.
+ /// Input: image and question text
+ /// Output: the answer to the question
+ Vqa(String),
+
+ /// Optical character recognition (OCR) task, recognizing text in an image.
+ /// Input: image
+ /// Output: recognized text
Ocr,
- Text2ImageGeneration,
+
+ /// OCR task with region information, recognizing text and returning its location in the image.
+ /// Input: image
+ /// Output: recognized text and its bounding box in the image
+ OcrWithRegion,
+}
+
+impl Task {
+ pub fn prompt_for_florence2(&self) -> anyhow::Result<String> {
+ let prompt = match self {
+ Self::Untitled => anyhow::bail!("No task specified."),
+ Self::Caption(0) => "What does the image describe?".to_string(),
+ Self::Caption(1) => "Describe in detail what is shown in the image.".to_string(),
+ Self::Caption(2) => "Describe with a paragraph what is shown in the image.".to_string(),
+ Self::Ocr => "What is the text in the image?".to_string(),
+ Self::OcrWithRegion => "What is the text in the image, with regions?".to_string(),
+ Self::ObjectDetection => {
+ "Locate the objects with category name in the image.".to_string()
+ }
+ Self::DenseRegionCaption => {
+ "Locate the objects in the image, with their descriptions.".to_string()
+ }
+ Self::RegionProposal => "Locate the region proposals in the image.".to_string(),
+ Self::OpenSetDetection(text) => {
+ format!("Locate {} in the image.", text)
+ }
+ Self::CaptionToPhraseGrounding(text) => {
+ format!("Locate the phrases in the caption: {}", text)
+ }
+ Self::ReferringExpressionSegmentation(text) => {
+ format!("Locate {} in the image with mask", text)
+ }
+ Self::RegionToSegmentation(x0, y0, x1, y1) => {
+ format!(
+ "What is the polygon mask of region <loc_{}><loc_{}><loc_{}><loc_{}>",
+ x0, y0, x1, y1
+ )
+ }
+ Self::RegionToCategory(x0, y0, x1, y1) => {
+ format!(
+ "What is the region <loc_{}><loc_{}><loc_{}><loc_{}>?",
+ x0, y0, x1, y1
+ )
+ }
+ Self::RegionToDescription(x0, y0, x1, y1) => {
+ format!(
+ "What does the region <loc_{}><loc_{}><loc_{}><loc_{}> describe?",
+ x0, y0, x1, y1
+ )
+ }
+ _ => anyhow::bail!("Unsupported task."),
+ };
+
+ Ok(prompt)
+ }
}
diff --git a/src/core/tokenizer_stream.rs b/src/core/tokenizer_stream.rs
index 5fb8025..495d69a 100644
--- a/src/core/tokenizer_stream.rs
+++ b/src/core/tokenizer_stream.rs
@@ -1,4 +1,4 @@
-// https://github.com/huggingface/candle/blob/2a8679509eb55232b37378442c4366343f6dcb11/candle-examples/src/token_output_stream.rs#L5
+// TODO: refactor
use anyhow::Result;
/// This is a wrapper around a tokenizer to ensure that tokens can be returned to the user in a
@@ -32,7 +32,6 @@ impl TokenizerStream {
}
}
- // https://github.com/huggingface/text-generation-inference/blob/5ba53d44a18983a4de32d122f4cb46f4a17d9ef6/server/text_generation_server/models/model.py#L68
pub fn next_token(&mut self, token: u32) -> Result