From 95797267017396a14c804aa316a956159d7e60fc Mon Sep 17 00:00:00 2001 From: amelie Date: Fri, 14 Feb 2025 10:52:10 +0100 Subject: [PATCH 1/2] add dynamic position encoding --- candle-examples/examples/siglip/main.rs | 9 +++- candle-transformers/src/models/siglip.rs | 53 +++++++++++++++++++----- 2 files changed, 51 insertions(+), 11 deletions(-) diff --git a/candle-examples/examples/siglip/main.rs b/candle-examples/examples/siglip/main.rs index be953c8764..bdd8f0969b 100644 --- a/candle-examples/examples/siglip/main.rs +++ b/candle-examples/examples/siglip/main.rs @@ -29,6 +29,9 @@ struct Args { #[arg(long, use_value_delimiter = true)] sequences: Option<Vec<String>>, + + #[arg(short, long)] + image_size: Option<usize>, } fn load_image<T: AsRef<std::path::Path>>(path: T, image_size: usize) -> anyhow::Result<Tensor> { @@ -81,7 +84,11 @@ pub fn main() -> anyhow::Result<()> { "candle-examples/examples/yolo-v8/assets/bike.jpg".to_string(), ], }; - let images = load_images(&vec_imgs, config.vision_config.image_size)?.to_device(&device)?; + let images = load_images( + &vec_imgs, + args.image_size.unwrap_or(config.vision_config.image_size), + )? + .to_device(&device)?; let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model_file.clone()], DType::F32, &device)?
}; let model = siglip::Model::new(&config, vb)?; diff --git a/candle-transformers/src/models/siglip.rs b/candle-transformers/src/models/siglip.rs index 932970ed3b..974880cb5f 100644 --- a/candle-transformers/src/models/siglip.rs +++ b/candle-transformers/src/models/siglip.rs @@ -434,8 +434,9 @@ impl Encoder { #[derive(Debug, Clone)] struct VisionEmbeddings { patch_embedding: candle_nn::Conv2d, - position_embedding: candle_nn::Embedding, - position_ids: Tensor, + position_embedding: Tensor, + patch_size: usize, + base_num_patches_per_side: usize, } impl VisionEmbeddings { @@ -451,25 +452,57 @@ impl VisionEmbeddings { conv2d_cfg, vb.pp("patch_embedding"), )?; - let num_patches = (cfg.image_size / cfg.patch_size).pow(2); - let position_ids = Tensor::arange(0, num_patches as i64, vb.device())?; - let position_embedding = - candle_nn::embedding(num_patches, cfg.hidden_size(), vb.pp("position_embedding"))?; + let num_patches_per_side = cfg.image_size / cfg.patch_size; + let embedder = candle_nn::embedding( + num_patches_per_side.pow(2), + cfg.hidden_size(), + vb.pp("position_embedding"), + )?; + let position_embedding = embedder.embeddings(); + let position_embedding = position_embedding + .reshape(( + 1, + num_patches_per_side, + num_patches_per_side, + cfg.hidden_size(), + ))? 
+ .permute((0, 3, 1, 2))?; Ok(Self { patch_embedding, position_embedding, - position_ids, + patch_size: cfg.patch_size, + base_num_patches_per_side: num_patches_per_side, }) } } impl Module for VisionEmbeddings { fn forward(&self, xs: &Tensor) -> Result<Tensor> { + //embed tokens let (_batch, _channels, _height, _width) = xs.dims4()?; let embeddings = xs.apply(&self.patch_embedding)?; - let embeddings = embeddings.flatten_from(2)?.transpose(1, 2)?; - let position_embedding = self.position_embedding.forward(&self.position_ids)?; - embeddings.broadcast_add(&position_embedding) + // interpolate position embeddings for the current image size (if needed) + let num_patches_h = _height / self.patch_size; + let num_patches_w = _width / self.patch_size; + let resized_position_embedding = if num_patches_w == self.base_num_patches_per_side + && num_patches_h == self.base_num_patches_per_side + { + print!("No position embeddings interpolation"); + self.position_embedding.clone() + } else { + print!( + "Interpolating position embeddings to ({}, {})", + num_patches_h, num_patches_w + ); + self.position_embedding + .interpolate2d(num_patches_h, num_patches_w)? + }; + // Add position embeddings to tokens and flatten from 2D patches to 1D sequence + let embeddings = embeddings + .broadcast_add(&resized_position_embedding)? + .flatten_from(2)?
+ .transpose(1, 2)?; + Ok(embeddings) } } From f8d0a7e93ee00a65c5bc18d35959b0c9454b5dae Mon Sep 17 00:00:00 2001 From: amelie Date: Fri, 14 Feb 2025 10:57:15 +0100 Subject: [PATCH 2/2] remove debug messages --- candle-transformers/src/models/siglip.rs | 5 ----- 1 file changed, 5 deletions(-) diff --git a/candle-transformers/src/models/siglip.rs b/candle-transformers/src/models/siglip.rs index 974880cb5f..b023c31f86 100644 --- a/candle-transformers/src/models/siglip.rs +++ b/candle-transformers/src/models/siglip.rs @@ -487,13 +487,8 @@ impl Module for VisionEmbeddings { let resized_position_embedding = if num_patches_w == self.base_num_patches_per_side && num_patches_h == self.base_num_patches_per_side { - print!("No position embeddings interpolation"); self.position_embedding.clone() } else { - print!( - "Interpolating position embeddings to ({}, {})", - num_patches_h, num_patches_w - ); self.position_embedding .interpolate2d(num_patches_h, num_patches_w)? };