Skip to content

Commit

Permalink
add dynamic position encoding
Browse files Browse the repository at this point in the history
  • Loading branch information
ameroyer committed Feb 14, 2025
1 parent 7c2449f commit 9579726
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 11 deletions.
9 changes: 8 additions & 1 deletion candle-examples/examples/siglip/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@ struct Args {

#[arg(long, use_value_delimiter = true)]
sequences: Option<Vec<String>>,

#[arg(short, long)]
image_size: Option<usize>,
}

fn load_image<T: AsRef<std::path::Path>>(path: T, image_size: usize) -> anyhow::Result<Tensor> {
Expand Down Expand Up @@ -81,7 +84,11 @@ pub fn main() -> anyhow::Result<()> {
"candle-examples/examples/yolo-v8/assets/bike.jpg".to_string(),
],
};
let images = load_images(&vec_imgs, config.vision_config.image_size)?.to_device(&device)?;
let images = load_images(
&vec_imgs,
args.image_size.unwrap_or(config.vision_config.image_size),
)?
.to_device(&device)?;
let vb =
unsafe { VarBuilder::from_mmaped_safetensors(&[model_file.clone()], DType::F32, &device)? };
let model = siglip::Model::new(&config, vb)?;
Expand Down
53 changes: 43 additions & 10 deletions candle-transformers/src/models/siglip.rs
Original file line number Diff line number Diff line change
Expand Up @@ -434,8 +434,9 @@ impl Encoder {
#[derive(Debug, Clone)]
struct VisionEmbeddings {
patch_embedding: candle_nn::Conv2d,
position_embedding: candle_nn::Embedding,
position_ids: Tensor,
position_embedding: Tensor,
patch_size: usize,
base_num_patches_per_side: usize,
}

impl VisionEmbeddings {
Expand All @@ -451,25 +452,57 @@ impl VisionEmbeddings {
conv2d_cfg,
vb.pp("patch_embedding"),
)?;
let num_patches = (cfg.image_size / cfg.patch_size).pow(2);
let position_ids = Tensor::arange(0, num_patches as i64, vb.device())?;
let position_embedding =
candle_nn::embedding(num_patches, cfg.hidden_size(), vb.pp("position_embedding"))?;
let num_patches_per_side = cfg.image_size / cfg.patch_size;
let embedder = candle_nn::embedding(
num_patches_per_side.pow(2),
cfg.hidden_size(),
vb.pp("position_embedding"),
)?;
let position_embedding = embedder.embeddings();
let position_embedding = position_embedding
.reshape((
1,
num_patches_per_side,
num_patches_per_side,
cfg.hidden_size(),
))?
.permute((0, 3, 1, 2))?;
Ok(Self {
patch_embedding,
position_embedding,
position_ids,
patch_size: cfg.patch_size,
base_num_patches_per_side: num_patches_per_side,
})
}
}

impl Module for VisionEmbeddings {
    /// Embeds an image batch into a sequence of patch tokens with position
    /// information added.
    ///
    /// `xs` must be a rank-4 tensor; its last two dimensions are read as the
    /// image height and width. When the resulting patch grid differs from the
    /// grid the position table was built for, the table is resized with
    /// `interpolate2d` before being added.
    ///
    /// # Errors
    ///
    /// Returns an error if `xs` is not rank 4 or if any tensor operation
    /// (convolution, interpolation, broadcast add, reshape) fails.
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let (_batch, _channels, height, width) = xs.dims4()?;
        let patches = xs.apply(&self.patch_embedding)?;

        // Patch grid implied by the current input resolution.
        // NOTE(review): assumes the patch convolution uses stride == patch_size
        // with no padding, so its output grid matches these values - confirm
        // against the conv config in the constructor.
        let num_patches_h = height / self.patch_size;
        let num_patches_w = width / self.patch_size;
        let matches_base_grid = num_patches_h == self.base_num_patches_per_side
            && num_patches_w == self.base_num_patches_per_side;

        // No logging here: this runs on every forward pass and library code
        // should not write to stdout.
        let with_positions = if matches_base_grid {
            patches.broadcast_add(&self.position_embedding)?
        } else {
            let resized = self
                .position_embedding
                .interpolate2d(num_patches_h, num_patches_w)?;
            patches.broadcast_add(&resized)?
        };

        // Collapse the 2D patch grid into one axis, then swap it with the
        // channel axis so the patch axis comes before the channel axis.
        with_positions.flatten_from(2)?.transpose(1, 2)
    }
}

Expand Down

0 comments on commit 9579726

Please sign in to comment.