Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix the rotary embeddings for the new phi implementation. #1582

Merged
merged 4 commits into from
Jan 13, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 16 additions & 18 deletions candle-transformers/src/models/phi.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ impl Config {

#[derive(Debug, Clone)]
struct RotaryEmbedding {
dim: usize,
sin: Tensor,
cos: Tensor,
}
Expand All @@ -55,29 +56,24 @@ impl RotaryEmbedding {
.to_dtype(DType::F32)?
.reshape((cfg.max_position_embeddings, 1))?;
let freqs = t.matmul(&inv_freq)?;
let emb = Tensor::cat(&[&freqs, &freqs], D::Minus1)?;
Ok(Self {
sin: freqs.sin()?,
cos: freqs.cos()?,
dim,
sin: emb.sin()?,
cos: emb.cos()?,
})
}

fn apply_rotary_emb(&self, xs: &Tensor, seqlen_offset: usize) -> Result<Tensor> {
let (_b_size, seqlen, _, _headdim) = xs.dims4()?;
let (_rotary_seqlen, rotary_dim) = self.cos.dims2()?;
let rotary_dim = rotary_dim * 2;
let xs_rot = xs.i((.., .., .., ..rotary_dim))?;
let xs_pass = xs.i((.., .., .., rotary_dim..))?;
let (_b_size, _num_heads, seq_len, _headdim) = xs.dims4()?;
let xs_rot = xs.i((.., .., .., ..self.dim))?;
let xs_pass = xs.i((.., .., .., self.dim..))?;
let xs12 = xs_rot.chunk(2, D::Minus1)?;
let (xs1, xs2) = (&xs12[0], &xs12[1]);
let c = self.cos.narrow(0, seqlen_offset, seqlen)?.unsqueeze(1)?;
let s = self.sin.narrow(0, seqlen_offset, seqlen)?.unsqueeze(1)?;
let xs_rot = Tensor::cat(
&[
(xs1.broadcast_mul(&c)? - xs2.broadcast_mul(&s)?)?,
(xs1.broadcast_mul(&s)? + xs2.broadcast_mul(&c)?)?,
],
D::Minus1,
)?;
let c = self.cos.narrow(0, seqlen_offset, seq_len)?;
let s = self.sin.narrow(0, seqlen_offset, seq_len)?;
let rotate_half = Tensor::cat(&[&xs2.neg()?, &xs1], D::Minus1)?;
let xs_rot = (xs_rot.broadcast_mul(&c)? + rotate_half.broadcast_mul(&s)?)?;
Tensor::cat(&[&xs_rot, &xs_pass], D::Minus1)
}
}
Expand All @@ -97,6 +93,8 @@ impl MLP {
Ok(Self {
fc1,
fc2,
// This does not match the mixformers implementation where Gelu is used rather than
// GeluNew.
act: cfg.hidden_act,
})
}
Expand Down Expand Up @@ -216,7 +214,7 @@ impl Attention {
// Rotary embeddings.
let seqlen_offset = match &self.kv_cache {
None => 0,
Some((prev_k, _)) => prev_k.dim(1)?,
Some((prev_k, _)) => prev_k.dim(2)?,
};
let query_states = self
.rotary_emb
Expand Down Expand Up @@ -351,7 +349,7 @@ impl Model {
Some(get_mask(seq_len, xs.device())?)
};
for layer in self.layers.iter_mut() {
xs = layer.forward(&xs, mask.as_ref())?
xs = layer.forward(&xs, mask.as_ref())?;
}
xs.apply(&self.final_layernorm)?
.narrow(1, seq_len - 1, 1)?
Expand Down
Loading