Commit a80dd70

Update the Phi model to use the updated architecture.

LaurentMazare committed Jan 13, 2024
1 parent a46864b commit a80dd70
Showing 2 changed files with 213 additions and 0 deletions.
1 change: 1 addition & 0 deletions candle-transformers/src/models/mod.rs
@@ -17,6 +17,7 @@ pub mod mixformer;
pub mod mixtral;
pub mod mpt;
pub mod persimmon;
pub mod phi;
pub mod quantized_blip;
pub mod quantized_blip_text;
pub mod quantized_llama;
212 changes: 212 additions & 0 deletions candle-transformers/src/models/phi.rs
@@ -0,0 +1,212 @@
#![allow(unused)]
use crate::models::with_tracing::{layer_norm, linear, Embedding, LayerNorm, Linear};
/// Phi model.
/// https://huggingface.co/microsoft/phi-2
/// There is an alternative implementation of the phi model in mixformer.rs.
/// This corresponds to the model update made with the following commit:
/// https://huggingface.co/microsoft/phi-2/commit/cb2f4533604d8b67de604e7df03bfe6f3ca22869
use candle::{DType, Device, IndexOp, Module, Result, Tensor, D};
use candle_nn::{Activation, VarBuilder};
use serde::Deserialize;

// https://huggingface.co/microsoft/phi-2/blob/main/configuration_phi.py
#[derive(Debug, Clone, PartialEq, Deserialize)]
pub struct Config {
pub(crate) vocab_size: usize,
pub(crate) hidden_size: usize,
pub(crate) intermediate_size: usize,
pub(crate) num_hidden_layers: usize,
pub(crate) num_attention_heads: usize,
pub(crate) num_key_value_heads: Option<usize>,
pub(crate) hidden_act: Activation,
pub(crate) max_position_embeddings: usize,
pub(crate) layer_norm_eps: f64,
pub(crate) tie_word_embeddings: bool,
pub(crate) rope_theta: f32,
pub(crate) partial_rotary_factor: f64,
pub(crate) qk_layer_norm: bool,
}

impl Config {
fn num_key_value_heads(&self) -> usize {
self.num_key_value_heads.unwrap_or(self.num_attention_heads)
}

fn head_dim(&self) -> usize {
self.hidden_size / self.num_attention_heads
}
}

#[derive(Debug, Clone)]
struct RotaryEmbedding {
sin: Tensor,
cos: Tensor,
}

impl RotaryEmbedding {
fn new(cfg: &Config, dev: &Device) -> Result<Self> {
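// Partial rotary embedding: only the first `partial_rotary_factor` fraction of each
// head's dimensions is rotated; the remaining dimensions pass through unchanged.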
let dim = (cfg.partial_rotary_factor * cfg.head_dim() as f64) as usize;
let inv_freq: Vec<_> = (0..dim)
.step_by(2)
.map(|i| 1f32 / cfg.rope_theta.powf(i as f32 / dim as f32))
.collect();
let inv_freq_len = inv_freq.len();
let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?;
let t = Tensor::arange(0u32, cfg.max_position_embeddings as u32, dev)?
.to_dtype(DType::F32)?
.reshape((cfg.max_position_embeddings, 1))?;
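// Outer product of positions and inverse frequencies: freqs[t, i] = t / theta^(2i / dim).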
let freqs = t.matmul(&inv_freq)?;
Ok(Self {
sin: freqs.sin()?,
cos: freqs.cos()?,
})
}
}

#[derive(Debug, Clone)]
#[allow(clippy::upper_case_acronyms)]
struct MLP {
fc1: Linear,
fc2: Linear,
act: Activation,
}

impl MLP {
fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
let fc1 = linear(cfg.hidden_size, cfg.intermediate_size, vb.pp("fc1"))?;
let fc2 = linear(cfg.intermediate_size, cfg.hidden_size, vb.pp("fc2"))?;
Ok(Self {
fc1,
fc2,
act: cfg.hidden_act,
})
}
}

impl Module for MLP {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
xs.apply(&self.fc1)?.apply(&self.act)?.apply(&self.fc2)
}
}

#[derive(Clone)]
struct Attention {
q_proj: Linear,
k_proj: Linear,
v_proj: Linear,
dense: Linear,
q_layernorm: Option<LayerNorm>,
k_layernorm: Option<LayerNorm>,
rotary_emb: RotaryEmbedding,
softmax_scale: f64,
num_heads: usize,
num_kv_heads: usize,
head_dim: usize,
span: tracing::Span,
}

fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor> {
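// Where `mask` is non-zero, replace the corresponding entries of `on_false` with the
// scalar `on_true` (f32::NEG_INFINITY at the call site) so that the following softmax
// assigns those positions zero weight.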
let shape = mask.shape();
let on_true = Tensor::new(on_true, on_false.device())?.broadcast_as(shape.dims())?;
let m = mask.where_cond(&on_true, on_false)?;
Ok(m)
}

impl Attention {
fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
let num_heads = cfg.num_attention_heads;
let num_kv_heads = cfg.num_key_value_heads();
let head_dim = cfg.head_dim();
let q_proj = linear(cfg.hidden_size, num_heads * head_dim, vb.pp("q_proj"))?;
let k_proj = linear(cfg.hidden_size, num_kv_heads * head_dim, vb.pp("k_proj"))?;
let v_proj = linear(cfg.hidden_size, num_kv_heads * head_dim, vb.pp("v_proj"))?;
let dense = linear(num_heads * head_dim, cfg.hidden_size, vb.pp("dense"))?;
// Alternative rope scalings are not supported.
let rotary_emb = RotaryEmbedding::new(cfg, vb.device())?;
let (q_layernorm, k_layernorm) = if cfg.qk_layer_norm {
let q_layernorm = layer_norm(head_dim, cfg.layer_norm_eps, vb.pp("q_layernorm"))?;
let k_layernorm = layer_norm(head_dim, cfg.layer_norm_eps, vb.pp("k_layernorm"))?;
(Some(q_layernorm), Some(k_layernorm))
} else {
(None, None)
};
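// Standard attention scaling of 1/sqrt(head_dim) applied to the query/key dot products.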
let softmax_scale = 1f64 / (head_dim as f64).sqrt();
Ok(Self {
q_proj,
k_proj,
v_proj,
dense,
q_layernorm,
k_layernorm,
rotary_emb,
softmax_scale,
num_heads,
num_kv_heads,
head_dim,
span: tracing::span!(tracing::Level::TRACE, "attention"),
})
}

fn forward(&mut self, xs: &Tensor, mask: Option<&Tensor>) -> Result<Tensor> {
let _enter = self.span.enter();
let (b_size, seq_len, _n_embd) = xs.dims3()?;
let query_states = self.q_proj.forward(xs)?;
let key_states = self.k_proj.forward(xs)?;
let value_states = self.v_proj.forward(xs)?;

let query_states = match &self.q_layernorm {
None => query_states,
Some(ln) => query_states.apply(ln)?,
};
let key_states = match &self.k_layernorm {
None => key_states,
Some(ln) => key_states.apply(ln)?,
};

let query_states = query_states
.reshape((b_size, seq_len, self.num_heads, self.head_dim))?
.transpose(1, 2)?;
let key_states = key_states
.reshape((b_size, seq_len, self.num_kv_heads, self.head_dim))?
.transpose(1, 2)?;
let value_states = value_states
.reshape((b_size, seq_len, self.num_kv_heads, self.head_dim))?
.transpose(1, 2)?;

// TODO: rotary embeddings.
// TODO: KV cache
// TODO: repeat kv.
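// Scaled dot-product attention: the scores are computed in f32 for numerical stability,
// masked, softmax-ed, and cast back to the value dtype before the weighted sum.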
let attn_weights = (query_states
.to_dtype(DType::F32)?
.matmul(&key_states.to_dtype(DType::F32)?.t()?)?
* self.softmax_scale)?;
let attn_weights = match mask {
None => attn_weights,
Some(mask) => masked_fill(
&attn_weights,
&mask.broadcast_left(b_size * self.num_heads)?,
f32::NEG_INFINITY,
)?,
};
let attn_weights =
candle_nn::ops::softmax_last_dim(&attn_weights)?.to_dtype(value_states.dtype())?;
let attn_output = attn_weights.matmul(&value_states)?;
let attn_output = attn_output
.transpose(1, 2)?
.reshape((b_size, seq_len, ()))?;
attn_output.apply(&self.dense)
}
}

#[derive(Clone)]
struct DecoderLayer {
self_attn: Attention,
mlp: MLP,
input_layernorm: LayerNorm,
}

#[derive(Clone)]
pub struct Model {
embed_tokens: Embedding,
layers: Vec<DecoderLayer>,
final_layernorm: LayerNorm,
lm_head: Linear,
}
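
A minimal usage sketch (not part of this commit): since Config derives serde's Deserialize, it can be parsed straight from a JSON document whose keys mirror the struct's field names. The example assumes serde_json is available as a dependency, and the numeric values are illustrative only, loosely modelled on the published phi-2 configuration.

use candle_transformers::models::phi::Config;

fn main() -> Result<(), serde_json::Error> {
    // Keys follow the field names of the Config struct above; values are illustrative.
    let json = r#"{
        "vocab_size": 51200,
        "hidden_size": 2560,
        "intermediate_size": 10240,
        "num_hidden_layers": 32,
        "num_attention_heads": 32,
        "num_key_value_heads": null,
        "hidden_act": "gelu_new",
        "max_position_embeddings": 2048,
        "layer_norm_eps": 1e-5,
        "tie_word_embeddings": false,
        "rope_theta": 10000.0,
        "partial_rotary_factor": 0.4,
        "qk_layer_norm": false
    }"#;
    // Config derives Debug, so the parsed struct can be printed directly.
    let cfg: Config = serde_json::from_str(json)?;
    println!("{cfg:?}");
    Ok(())
}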
