Update the Phi model to use the updated architecture.
1 parent a46864b · commit a80dd70. Showing 2 changed files with 213 additions and 0 deletions.
@@ -0,0 +1,212 @@
#![allow(unused)]
use crate::models::with_tracing::{layer_norm, linear, Embedding, LayerNorm, Linear};
/// Phi model.
/// https://huggingface.co/microsoft/phi-2
/// There is an alternative implementation of the phi model in mixformers.rs.
/// This corresponds to the model update made with the following commit:
/// https://huggingface.co/microsoft/phi-2/commit/cb2f4533604d8b67de604e7df03bfe6f3ca22869
use candle::{DType, Device, IndexOp, Module, Result, Tensor, D};
use candle_nn::{Activation, VarBuilder};
use serde::Deserialize;

// https://huggingface.co/microsoft/phi-2/blob/main/configuration_phi.py
#[derive(Debug, Clone, PartialEq, Deserialize)]
pub struct Config {
    pub(crate) vocab_size: usize,
    pub(crate) hidden_size: usize,
    pub(crate) intermediate_size: usize,
    pub(crate) num_hidden_layers: usize,
    pub(crate) num_attention_heads: usize,
    pub(crate) num_key_value_heads: Option<usize>,
    pub(crate) hidden_act: Activation,
    pub(crate) max_position_embeddings: usize,
    pub(crate) layer_norm_eps: f64,
    pub(crate) tie_word_embeddings: bool,
    pub(crate) rope_theta: f32,
    pub(crate) partial_rotary_factor: f64,
    pub(crate) qk_layer_norm: bool,
}

impl Config {
    fn num_key_value_heads(&self) -> usize {
        self.num_key_value_heads.unwrap_or(self.num_attention_heads)
    }

    fn head_dim(&self) -> usize {
        self.hidden_size / self.num_attention_heads
    }
}
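
// Example (not part of this commit): a Config filled in by hand with the values
// published in the microsoft/phi-2 config.json; the exact numbers are quoted from
// memory and should be treated as an assumption. In practice the config is obtained
// by deserializing that file rather than by constructing the struct manually.
fn phi2_config_example() -> Config {
    Config {
        vocab_size: 51200,
        hidden_size: 2560,
        intermediate_size: 10240,
        num_hidden_layers: 32,
        num_attention_heads: 32,
        num_key_value_heads: None,
        hidden_act: Activation::NewGelu,
        max_position_embeddings: 2048,
        layer_norm_eps: 1e-5,
        tie_word_embeddings: false,
        rope_theta: 10_000.0,
        partial_rotary_factor: 0.4,
        qk_layer_norm: false,
    }
}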
|
||
#[derive(Debug, Clone)] | ||
struct RotaryEmbedding { | ||
sin: Tensor, | ||
cos: Tensor, | ||
} | ||
|
||
impl RotaryEmbedding { | ||
fn new(cfg: &Config, dev: &Device) -> Result<Self> { | ||
let dim = (cfg.partial_rotary_factor * cfg.head_dim() as f64) as usize; | ||
let inv_freq: Vec<_> = (0..dim) | ||
.step_by(2) | ||
.map(|i| 1f32 / cfg.rope_theta.powf(i as f32 / dim as f32)) | ||
.collect(); | ||
let inv_freq_len = inv_freq.len(); | ||
let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?; | ||
let t = Tensor::arange(0u32, cfg.max_position_embeddings as u32, dev)? | ||
.to_dtype(DType::F32)? | ||
.reshape((cfg.max_position_embeddings, 1))?; | ||
let freqs = t.matmul(&inv_freq)?; | ||
Ok(Self { | ||
sin: freqs.sin()?, | ||
cos: freqs.cos()?, | ||
}) | ||
} | ||
} | ||
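
// Sketch (not part of this commit): one way the tables above could be applied to the
// query/key heads, covering the "TODO: rotary embeddings" left in Attention::forward
// below. Only the first `rot_dim` features of each head are rotated (NeoX-style); the
// remainder passes through unchanged. The helper name and the `seqlen_offset`
// parameter are assumptions for illustration, not an API defined by this file.
fn apply_partial_rotary_sketch(
    xs: &Tensor,  // (batch, num_heads, seq_len, head_dim)
    sin: &Tensor, // (max_position_embeddings, rot_dim / 2)
    cos: &Tensor, // (max_position_embeddings, rot_dim / 2)
    seqlen_offset: usize,
) -> Result<Tensor> {
    let (_b, _h, seq_len, head_dim) = xs.dims4()?;
    let rot_dim = 2 * cos.dim(D::Minus1)?;
    // Duplicate the tables so they line up with the [x1, x2] half split used below.
    let cos = Tensor::cat(&[cos, cos], D::Minus1)?.narrow(0, seqlen_offset, seq_len)?;
    let sin = Tensor::cat(&[sin, sin], D::Minus1)?.narrow(0, seqlen_offset, seq_len)?;
    let xs_rot = xs.narrow(D::Minus1, 0, rot_dim)?;
    let xs_pass = xs.narrow(D::Minus1, rot_dim, head_dim - rot_dim)?;
    let xs1 = xs_rot.narrow(D::Minus1, 0, rot_dim / 2)?;
    let xs2 = xs_rot.narrow(D::Minus1, rot_dim / 2, rot_dim / 2)?;
    let rotate_half = Tensor::cat(&[&xs2.neg()?, &xs1], D::Minus1)?;
    let xs_rot = (xs_rot.broadcast_mul(&cos)? + rotate_half.broadcast_mul(&sin)?)?;
    Tensor::cat(&[&xs_rot, &xs_pass], D::Minus1)
}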

#[derive(Debug, Clone)]
#[allow(clippy::upper_case_acronyms)]
struct MLP {
    fc1: Linear,
    fc2: Linear,
    act: Activation,
}

impl MLP {
    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
        let fc1 = linear(cfg.hidden_size, cfg.intermediate_size, vb.pp("fc1"))?;
        let fc2 = linear(cfg.intermediate_size, cfg.hidden_size, vb.pp("fc2"))?;
        Ok(Self {
            fc1,
            fc2,
            act: cfg.hidden_act,
        })
    }
}

impl Module for MLP {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        xs.apply(&self.fc1)?.apply(&self.act)?.apply(&self.fc2)
    }
}

#[derive(Clone)]
struct Attention {
    q_proj: Linear,
    k_proj: Linear,
    v_proj: Linear,
    dense: Linear,
    q_layernorm: Option<LayerNorm>,
    k_layernorm: Option<LayerNorm>,
    rotary_emb: RotaryEmbedding,
    softmax_scale: f64,
    num_heads: usize,
    num_kv_heads: usize,
    span: tracing::Span,
}

fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor> {
    let shape = mask.shape();
    let on_true = Tensor::new(on_true, on_false.device())?.broadcast_as(shape.dims())?;
    let m = mask.where_cond(&on_true, on_false)?;
    Ok(m)
}
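
// Sketch (not part of this commit): the kind of mask `masked_fill` is written to
// consume, an upper-triangular (seq_len, seq_len) u8 tensor in which 1 marks the
// future positions to exclude from a causal attention step. The helper name is an
// assumption for illustration.
fn causal_mask_sketch(seq_len: usize, dev: &Device) -> Result<Tensor> {
    let mask: Vec<u8> = (0..seq_len)
        .flat_map(|i| (0..seq_len).map(move |j| u8::from(j > i)))
        .collect();
    Tensor::from_vec(mask, (seq_len, seq_len), dev)
}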

impl Attention {
    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
        let num_heads = cfg.num_attention_heads;
        let num_kv_heads = cfg.num_key_value_heads();
        let head_dim = cfg.head_dim();
        let q_proj = linear(cfg.hidden_size, num_heads * head_dim, vb.pp("q_proj"))?;
        let k_proj = linear(cfg.hidden_size, num_kv_heads * head_dim, vb.pp("k_proj"))?;
        let v_proj = linear(cfg.hidden_size, num_kv_heads * head_dim, vb.pp("v_proj"))?;
        let dense = linear(num_heads * head_dim, cfg.hidden_size, vb.pp("dense"))?;
        // Alternative rope scalings are not supported.
        let rotary_emb = RotaryEmbedding::new(cfg, vb.device())?;
        let (q_layernorm, k_layernorm) = if cfg.qk_layer_norm {
            let q_layernorm = layer_norm(head_dim, cfg.layer_norm_eps, vb.pp("q_layernorm"))?;
            let k_layernorm = layer_norm(head_dim, cfg.layer_norm_eps, vb.pp("k_layernorm"))?;
            (Some(q_layernorm), Some(k_layernorm))
        } else {
            (None, None)
        };
        let softmax_scale = 1f64 / (head_dim as f64).sqrt();
        Ok(Self {
            q_proj,
            k_proj,
            v_proj,
            dense,
            q_layernorm,
            k_layernorm,
            rotary_emb,
            softmax_scale,
            num_heads,
            num_kv_heads,
            span: tracing::span!(tracing::Level::TRACE, "attention"),
        })
    }

    fn forward(&mut self, xs: &Tensor, mask: Option<&Tensor>) -> Result<Tensor> {
        let _enter = self.span.enter();
        let (b_size, seq_len, _n_embd) = xs.dims3()?;
        let query_states = self.q_proj.forward(xs)?;
        let key_states = self.k_proj.forward(xs)?;
        let value_states = self.v_proj.forward(xs)?;

        let query_states = match &self.q_layernorm {
            None => query_states,
            Some(ln) => query_states.apply(ln)?,
        };
        let key_states = match &self.k_layernorm {
            None => key_states,
            Some(ln) => key_states.apply(ln)?,
        };

        // Split the projected states into heads:
        // (b_size, seq_len, num_heads * head_dim) -> (b_size, num_heads, seq_len, head_dim),
        // leaving the head dimension to be inferred.
        let query_states = query_states
            .reshape((b_size, seq_len, self.num_heads, ()))?
            .transpose(1, 2)?;
        let key_states = key_states
            .reshape((b_size, seq_len, self.num_kv_heads, ()))?
            .transpose(1, 2)?;
        let value_states = value_states
            .reshape((b_size, seq_len, self.num_kv_heads, ()))?
            .transpose(1, 2)?;

        // TODO: rotary embeddings.
        // TODO: KV cache
        // TODO: repeat kv.
        let attn_weights = (query_states
            .to_dtype(DType::F32)?
            .matmul(&key_states.to_dtype(DType::F32)?.t()?)?
            * self.softmax_scale)?;
        let attn_weights = match mask {
            None => attn_weights,
            Some(mask) => masked_fill(
                &attn_weights,
                &mask.broadcast_left((b_size, self.num_heads))?,
                f32::NEG_INFINITY,
            )?,
        };
        let attn_weights =
            candle_nn::ops::softmax_last_dim(&attn_weights)?.to_dtype(value_states.dtype())?;
        let attn_output = attn_weights.matmul(&value_states)?;
        let attn_output = attn_output
            .transpose(1, 2)?
            .reshape((b_size, seq_len, ()))?;
        attn_output.apply(&self.dense)
    }
}
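
// Sketch (not part of this commit): the usual resolution of the "TODO: repeat kv"
// above when num_kv_heads < num_heads, i.e. each key/value head is tiled
// n_rep = num_heads / num_kv_heads times so its shape lines up with the query heads.
// The helper name is an assumption for illustration.
fn repeat_kv_sketch(xs: Tensor, n_rep: usize) -> Result<Tensor> {
    if n_rep == 1 {
        Ok(xs)
    } else {
        let (b, n_kv_heads, seq_len, head_dim) = xs.dims4()?;
        // (b, kv, seq, dim) -> (b, kv, n_rep, seq, dim) -> (b, kv * n_rep, seq, dim).
        let xs = xs.unsqueeze(2)?;
        Tensor::cat(&vec![&xs; n_rep], 2)?.reshape((b, n_kv_heads * n_rep, seq_len, head_dim))
    }
}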

#[derive(Clone)]
struct DecoderLayer {
    self_attn: Attention,
    mlp: MLP,
    input_layernorm: LayerNorm,
}

#[derive(Clone)]
pub struct Model {
    embed_tokens: Embedding,
    layers: Vec<DecoderLayer>,
    final_layernorm: LayerNorm,
    lm_head: Linear,
}