Skip to content

Commit

Permalink
Apply the rotary embeddings.
Browse files Browse the repository at this point in the history
  • Loading branch information
LaurentMazare committed Jan 13, 2024
1 parent 7d947ee commit d8cd8eb
Showing 1 changed file with 26 additions and 1 deletion.
27 changes: 26 additions & 1 deletion candle-transformers/src/models/phi.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,26 @@ impl RotaryEmbedding {
cos: freqs.cos()?,
})
}

fn apply_rotary_emb(&self, xs: &Tensor, seqlen_offset: usize) -> Result<Tensor> {
let (_b_size, seqlen, _, _headdim) = xs.dims4()?;
let (_rotary_seqlen, rotary_dim) = self.cos.dims2()?;
let rotary_dim = rotary_dim * 2;
let xs_rot = xs.i((.., .., .., ..rotary_dim))?;
let xs_pass = xs.i((.., .., .., rotary_dim..))?;
let xs12 = xs_rot.chunk(2, D::Minus1)?;
let (xs1, xs2) = (&xs12[0], &xs12[1]);
let c = self.cos.narrow(0, seqlen_offset, seqlen)?.unsqueeze(1)?;
let s = self.sin.narrow(0, seqlen_offset, seqlen)?.unsqueeze(1)?;
let xs_rot = Tensor::cat(
&[
(xs1.broadcast_mul(&c)? - xs2.broadcast_mul(&s)?)?,
(xs1.broadcast_mul(&s)? + xs2.broadcast_mul(&c)?)?,
],
D::Minus1,
)?;
Tensor::cat(&[&xs_rot, &xs_pass], D::Minus1)
}
}

#[derive(Debug, Clone)]
Expand Down Expand Up @@ -197,7 +217,12 @@ impl Attention {
None => 0,
Some((prev_k, _)) => prev_k.dim(1)?,
};
// TODO: rotary embeddings.
let query_states = self
.rotary_emb
.apply_rotary_emb(&query_states, seqlen_offset)?;
let key_states = self
.rotary_emb
.apply_rotary_emb(&key_states, seqlen_offset)?;

// KV cache.
let (key_states, value_states) = match &self.kv_cache {
Expand Down

0 comments on commit d8cd8eb

Please sign in to comment.