From f8cc0ba5879f27608051fbb192aa7349a2a4834f Mon Sep 17 00:00:00 2001 From: laurent Date: Sat, 13 Jan 2024 19:02:42 +0100 Subject: [PATCH 1/4] Fix the rotary embeddings for the new phi implementation. --- candle-transformers/src/models/phi.rs | 28 ++++++++++++--------------- 1 file changed, 12 insertions(+), 16 deletions(-) diff --git a/candle-transformers/src/models/phi.rs b/candle-transformers/src/models/phi.rs index a635f3ce07..b2fd4e9cb2 100644 --- a/candle-transformers/src/models/phi.rs +++ b/candle-transformers/src/models/phi.rs @@ -38,6 +38,7 @@ impl Config { #[derive(Debug, Clone)] struct RotaryEmbedding { + dim: usize, sin: Tensor, cos: Tensor, } @@ -55,29 +56,24 @@ impl RotaryEmbedding { .to_dtype(DType::F32)? .reshape((cfg.max_position_embeddings, 1))?; let freqs = t.matmul(&inv_freq)?; + let emb = Tensor::cat(&[&freqs, &freqs], D::Minus1)?; Ok(Self { - sin: freqs.sin()?, - cos: freqs.cos()?, + dim, + sin: emb.sin()?, + cos: emb.cos()?, }) } fn apply_rotary_emb(&self, xs: &Tensor, seqlen_offset: usize) -> Result<Tensor> { - let (_b_size, seqlen, _, _headdim) = xs.dims4()?; - let (_rotary_seqlen, rotary_dim) = self.cos.dims2()?; - let rotary_dim = rotary_dim * 2; - let xs_rot = xs.i((.., .., .., ..rotary_dim))?; - let xs_pass = xs.i((.., .., .., rotary_dim..))?; + let (_b_size, _num_heads, seq_len, _headdim) = xs.dims4()?; + let xs_rot = xs.i((.., .., .., ..self.dim))?; + let xs_pass = xs.i((.., .., .., self.dim..))?; let xs12 = xs_rot.chunk(2, D::Minus1)?; let (xs1, xs2) = (&xs12[0], &xs12[1]); - let c = self.cos.narrow(0, seqlen_offset, seqlen)?.unsqueeze(1)?; - let s = self.sin.narrow(0, seqlen_offset, seqlen)?.unsqueeze(1)?; - let xs_rot = Tensor::cat( - &[ - (xs1.broadcast_mul(&c)? - xs2.broadcast_mul(&s)?)?, - (xs1.broadcast_mul(&s)?
+ xs2.broadcast_mul(&c)?)?, - ], - D::Minus1, - )?; + let c = self.cos.narrow(0, seqlen_offset, seq_len)?; + let s = self.sin.narrow(0, seqlen_offset, seq_len)?; + let rotate_half = Tensor::cat(&[&xs2.neg()?, &xs1], D::Minus1)?; + let xs_rot = (xs_rot.broadcast_mul(&c)? + rotate_half.broadcast_mul(&s)?)?; Tensor::cat(&[&xs_rot, &xs_pass], D::Minus1) } } From 81422af12874128bc31e13dffe7b2e25939e3f16 Mon Sep 17 00:00:00 2001 From: laurent Date: Sat, 13 Jan 2024 19:27:17 +0100 Subject: [PATCH 2/4] Match the activation. --- candle-transformers/src/models/phi.rs | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/candle-transformers/src/models/phi.rs b/candle-transformers/src/models/phi.rs index b2fd4e9cb2..442ab456ca 100644 --- a/candle-transformers/src/models/phi.rs +++ b/candle-transformers/src/models/phi.rs @@ -93,7 +93,9 @@ impl MLP { Ok(Self { fc1, fc2, - act: cfg.hidden_act, + // Bypass the config activation for now to be in line with the mixformers + // implementation + act: Activation::Gelu, }) } } @@ -347,7 +349,7 @@ impl Model { Some(get_mask(seq_len, xs.device())?) }; for layer in self.layers.iter_mut() { - xs = layer.forward(&xs, mask.as_ref())? + xs = layer.forward(&xs, mask.as_ref())?; } xs.apply(&self.final_layernorm)? .narrow(1, seq_len - 1, 1)? From 816c114cf9844d4e7320de84351f7c27f00f3820 Mon Sep 17 00:00:00 2001 From: laurent Date: Sat, 13 Jan 2024 19:32:00 +0100 Subject: [PATCH 3/4] KV cache fix. --- candle-transformers/src/models/phi.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/candle-transformers/src/models/phi.rs b/candle-transformers/src/models/phi.rs index 442ab456ca..ed0de80ce7 100644 --- a/candle-transformers/src/models/phi.rs +++ b/candle-transformers/src/models/phi.rs @@ -214,7 +214,7 @@ impl Attention { // Rotary embeddings. 
let seqlen_offset = match &self.kv_cache { None => 0, - Some((prev_k, _)) => prev_k.dim(1)?, + Some((prev_k, _)) => prev_k.dim(2)?, }; let query_states = self .rotary_emb From 1602cc4f24e255cb540b7efd4a6ac0a6370295f7 Mon Sep 17 00:00:00 2001 From: laurent Date: Sat, 13 Jan 2024 19:35:26 +0100 Subject: [PATCH 4/4] Use the config activation function. --- candle-transformers/src/models/phi.rs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/candle-transformers/src/models/phi.rs b/candle-transformers/src/models/phi.rs index ed0de80ce7..8bf357e731 100644 --- a/candle-transformers/src/models/phi.rs +++ b/candle-transformers/src/models/phi.rs @@ -93,9 +93,9 @@ impl MLP { Ok(Self { fc1, fc2, - // Bypass the config activation for now to be in line with the mixformers - // implementation - act: Activation::Gelu, + // This does not match the mixformers implementation where Gelu is used rather than + // GeluNew. + act: cfg.hidden_act, }) } }