Add blockwise fp8 linear
EricLBuehler committed Jan 27, 2025
1 parent dcfc563 commit 28cb4bb
Showing 7 changed files with 187 additions and 139 deletions.
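
For orientation before the per-file diffs: the blockwise FP8 scheme this commit wires in stores the weight matrix in FP8 (E4M3) together with one scale per fixed-size tile, and dequantization multiplies each element by its tile's scale. The sketch below is a minimal, self-contained illustration in plain Rust — f32 stand-ins replace real FP8 values, the 2x2 block size is arbitrary, and none of this is the commit's actual kernel code (448.0 is the largest finite E4M3 value).

```rust
const FP8_E4M3_MAX: f32 = 448.0; // largest finite FP8 E4M3 value

/// Quantize `w` (rows x cols) with one scale per `bs` x `bs` tile.
/// Returns (quantized values as f32 stand-ins, per-tile scales).
fn quantize_blockwise(w: &[Vec<f32>], bs: usize) -> (Vec<Vec<f32>>, Vec<Vec<f32>>) {
    let (rows, cols) = (w.len(), w[0].len());
    let mut scales = vec![vec![0f32; cols.div_ceil(bs)]; rows.div_ceil(bs)];
    // One scale per tile: amax(tile) / FP8_E4M3_MAX, so every tile fits the FP8 range
    // (a real implementation would also guard all-zero tiles).
    for (r, row) in w.iter().enumerate() {
        for (c, v) in row.iter().enumerate() {
            let s = &mut scales[r / bs][c / bs];
            *s = s.max(v.abs() / FP8_E4M3_MAX);
        }
    }
    let q: Vec<Vec<f32>> = w
        .iter()
        .enumerate()
        .map(|(r, row)| {
            row.iter()
                .enumerate()
                // Real code would round this onto the FP8 grid; f32 stands in here.
                .map(|(c, v)| v / scales[r / bs][c / bs])
                .collect()
        })
        .collect();
    (q, scales)
}

/// Dequantize: multiply each element by the scale of the tile it belongs to.
fn dequantize_blockwise(q: &[Vec<f32>], scales: &[Vec<f32>], bs: usize) -> Vec<Vec<f32>> {
    q.iter()
        .enumerate()
        .map(|(r, row)| {
            row.iter()
                .enumerate()
                .map(|(c, v)| v * scales[r / bs][c / bs])
                .collect()
        })
        .collect()
}

fn main() {
    // A 4x6 weight with 2x2 tiles -> a 2x3 grid of scales.
    let w: Vec<Vec<f32>> = (0..4)
        .map(|r| (0..6).map(|c| (r * 6 + c) as f32 - 11.5).collect())
        .collect();
    let (q, scales) = quantize_blockwise(&w, 2);
    let back = dequantize_blockwise(&q, &scales, 2);
    assert_eq!((scales.len(), scales[0].len()), (2, 3));
    // Without the FP8 rounding step the round trip is exact up to float error.
    assert!((w[3][5] - back[3][5]).abs() < 1e-5);
    println!("per-tile scales: {scales:?}");
}
```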
78 changes: 0 additions & 78 deletions candle-core/src/cpu_backend/utils.rs
@@ -61,84 +61,6 @@ pub trait Map2 {
}
}

pub trait Map3 {
const OP: &'static str;
#[allow(clippy::too_many_arguments)]
fn f<T: WithDType>(
&self,
v1: &[T],
l1: &Layout,
v2: &[T],
l2: &Layout,
v3: &mut [T],
l3: &Layout,
s: Option<f64>,
) -> Result<()>;

#[allow(clippy::too_many_arguments)]
fn map(
&self,
v1: &C,
l1: &Layout,
v2: &C,
l2: &Layout,
v3: &mut C,
l3: &Layout,
s: Option<f64>,
) -> Result<()> {
let v3d = v3.dtype();
match (v1, v2, v3) {
(C::U8(v1), C::U8(v2), C::U8(v3)) => Ok(self.f(v1, l1, v2, l2, v3, l3, s)?),
(C::U32(v1), C::U32(v2), C::U32(v3)) => Ok(self.f(v1, l1, v2, l2, v3, l3, s)?),
(C::I64(v1), C::I64(v2), C::I64(v3)) => Ok(self.f(v1, l1, v2, l2, v3, l3, s)?),
(C::BF16(v1), C::BF16(v2), C::BF16(v3)) => Ok(self.f(v1, l1, v2, l2, v3, l3, s)?),
(C::F16(v1), C::F16(v2), C::F16(v3)) => Ok(self.f(v1, l1, v2, l2, v3, l3, s)?),
(C::F32(v1), C::F32(v2), C::F32(v3)) => Ok(self.f(v1, l1, v2, l2, v3, l3, s)?),
(C::F64(v1), C::F64(v2), C::F64(v3)) => Ok(self.f(v1, l1, v2, l2, v3, l3, s)?),
(C::F8E4M3(v1), C::F8E4M3(v2), C::F8E4M3(v3)) => Ok(self.f(v1, l1, v2, l2, v3, l3, s)?),
_ => Err(Error::DTypeMismatchBinaryOp3 {
lhs: v1.dtype(),
rhs: v2.dtype(),
c: v3d,
op: Self::OP,
}
.bt()),
}
}
}

pub trait Map2Alpha {
const OP: &'static str;
#[allow(clippy::too_many_arguments)]
fn f<T: WithDType>(
&self,
v1: &[T],
l1: &Layout,
v2: &[T],
l2: &Layout,
s: Option<f64>,
) -> Result<Vec<T>>;

#[allow(clippy::too_many_arguments)]
fn map(&self, v1: &C, l1: &Layout, v2: &C, l2: &Layout, s: Option<f64>) -> Result<C> {
match (v1, v2) {
(C::U8(v1), C::U8(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2, s)?)),
(C::U32(v1), C::U32(v2)) => Ok(C::U32(self.f(v1, l1, v2, l2, s)?)),
(C::I64(v1), C::I64(v2)) => Ok(C::I64(self.f(v1, l1, v2, l2, s)?)),
(C::BF16(v1), C::BF16(v2)) => Ok(C::BF16(self.f(v1, l1, v2, l2, s)?)),
(C::F16(v1), C::F16(v2)) => Ok(C::F16(self.f(v1, l1, v2, l2, s)?)),
(C::F32(v1), C::F32(v2)) => Ok(C::F32(self.f(v1, l1, v2, l2, s)?)),
(C::F64(v1), C::F64(v2)) => Ok(C::F64(self.f(v1, l1, v2, l2, s)?)),
_ => Err(Error::DTypeMismatchBinaryOp {
lhs: v1.dtype(),
rhs: v2.dtype(),
op: Self::OP,
}
.bt()),
}
}
}

pub trait Map2U8 {
const OP: &'static str;
fn f<T: WithDType>(&self, v1: &[T], l1: &Layout, v2: &[T], l2: &Layout) -> Result<Vec<u8>>;
3 changes: 1 addition & 2 deletions candle-core/src/cuda_backend/device.rs
@@ -269,8 +269,7 @@ impl BackendDevice for CudaDevice {
let slice = match dtype {
// TODO: Add support for F16 and BF16 though this is likely to require some upstream
// cudarc changes.
DType::U8 | DType::U32 | DType::I64 | DType::F16 | DType::BF16
| DType::F8E4M3 => {
DType::U8 | DType::U32 | DType::I64 | DType::F16 | DType::BF16 | DType::F8E4M3 => {
Err(CudaError::UnsupportedDtype {
dtype,
op: "rand_uniform",
36 changes: 1 addition & 35 deletions candle-core/src/sort.rs
@@ -52,32 +52,6 @@ impl ArgSort {
}
}

impl crate::CustomOp1 for ArgSort {
fn name(&self) -> &'static str {
"argsort"
}

fn cpu_fwd(
&self,
storage: &crate::CpuStorage,
layout: &crate::Layout,
) -> Result<(crate::CpuStorage, crate::Shape)> {
let sort_indexes = match storage {
crate::CpuStorage::U8(vs) => self.asort(vs, layout),
crate::CpuStorage::U32(vs) => self.asort(vs, layout),
crate::CpuStorage::I16(vs) => self.asort(vs, layout),
crate::CpuStorage::I32(vs) => self.asort(vs, layout),
crate::CpuStorage::I64(vs) => self.asort(vs, layout),
crate::CpuStorage::BF16(vs) => self.asort(vs, layout),
crate::CpuStorage::F16(vs) => self.asort(vs, layout),
crate::CpuStorage::F32(vs) => self.asort(vs, layout),
crate::CpuStorage::F64(vs) => self.asort(vs, layout),
crate::CpuStorage::F8E4M3(vs) => self.asort(vs, layout),
};
let sort_indexes = crate::CpuStorage::U32(sort_indexes);
Ok((sort_indexes, layout.shape().into()))
}

#[cfg(feature = "cuda")]
mod cuda {
use super::*;
@@ -139,6 +113,7 @@ impl crate::CustomOp1 for ArgSort {
crate::CpuStorage::F16(vs) => self.asort(vs, layout),
crate::CpuStorage::F32(vs) => self.asort(vs, layout),
crate::CpuStorage::F64(vs) => self.asort(vs, layout),
crate::CpuStorage::F8E4M3(vs) => self.asort(vs, layout),
};
let sort_indexes = crate::CpuStorage::U32(sort_indexes);
Ok((sort_indexes, layout.shape().into()))
@@ -224,15 +199,6 @@ impl crate::CustomOp1 for ArgSort {
}
}

#[allow(unused)]
fn next_power_of_2(x: usize) -> usize {
let mut n = 1;
while n < x {
n *= 2
}
n
}

impl Tensor {
/// Returns the indices that sort the tensor along the last dimension.
///
1 change: 1 addition & 0 deletions candle-transformers/Cargo.toml
@@ -24,6 +24,7 @@ serde = { workspace = true }
serde_json = { workspace = true }
serde_plain = { workspace = true }
tracing = { workspace = true }
float8 = { workspace = true }

[features]
default = []
3 changes: 2 additions & 1 deletion candle-transformers/src/models/deepseekv3/mod.rs
@@ -1,2 +1,3 @@
mod ops;
pub mod model;
mod ops;
mod quant;
89 changes: 66 additions & 23 deletions candle-transformers/src/models/deepseekv3/model.rs
@@ -6,10 +6,12 @@ use candle::{
shape::Dim, CpuStorage, CustomOp1, DType, Device, Error, IndexOp, Layout, Result, Shape,
Tensor, WithDType, D,
};
use candle_nn::{embedding, rms_norm, Activation, Embedding, Linear, Module, RmsNorm, VarBuilder};
use candle_nn::{embedding, rms_norm, Activation, Embedding, Module, RmsNorm, VarBuilder};
use rayon::iter::{IntoParallelRefIterator, ParallelIterator};
use serde::Deserialize;

use super::quant::{self, BlockwiseFP8Linear, QuantizedConfig};

struct NonZero {}

impl NonZero {
@@ -272,6 +274,7 @@ pub struct DeepSeekV2Config {
pub(crate) qk_nope_head_dim: usize,
pub(crate) n_group: usize,
pub(crate) topk_group: usize,
pub(crate) quantization_config: Option<QuantizedConfig>,
}

#[derive(Debug, Clone, Deserialize)]
@@ -285,7 +288,7 @@ pub enum ScaledRopeType {
#[serde(alias = "dynamic")]
Dynamic,
#[serde(alias = "linear")]
Linear,
BlockwiseFP8Linear,
}

impl FromStr for ScaledRopeType {
Expand All @@ -294,7 +297,7 @@ impl FromStr for ScaledRopeType {
match s {
"su" | "longrope" => Ok(Self::Su),
"yarn" => Ok(Self::Yarn),
"linear" => Ok(Self::Linear),
"linear" => Ok(Self::BlockwiseFP8Linear),
"dynamic" => Ok(Self::Dynamic),
_ => Err(candle::Error::Msg(
"Expected either `su` or `yarn` scaled RoPE type.".to_string(),
@@ -322,7 +325,7 @@ pub enum DeepSeekV2RopeScaling {
#[serde(rename = "type")]
scaling_type: ScaledRopeType,
},
LinearOrDynamic {
BlockwiseFP8LinearOrDynamic {
#[serde(rename = "type")]
scaling_type: ScaledRopeType,
factor: f64,
@@ -452,7 +455,7 @@ impl DeepSeekV2RotaryEmbedding {

pub fn new(cfg: &DeepSeekV2RopeConfig, dtype: DType, dev: &Device) -> Result<Self> {
match &cfg.rope_scaling {
Some(DeepSeekV2RopeScaling::LinearOrDynamic {
Some(DeepSeekV2RopeScaling::BlockwiseFP8LinearOrDynamic {
scaling_type: _,
factor: _,
}) => candle::bail!("linear and dynamic rope are not implemented yet!"),
@@ -518,8 +521,12 @@ impl DeepSeekV2Config {
}

enum QProj {
Plain(Linear),
Lora { a: Linear, norm: RmsNorm, b: Linear },
Plain(BlockwiseFP8Linear),
Lora {
a: BlockwiseFP8Linear,
norm: RmsNorm,
b: BlockwiseFP8Linear,
},
}

impl QProj {
@@ -533,10 +540,10 @@ impl QProj {

struct Attention {
q: QProj,
kv_a_proj_with_mqa: Linear,
kv_a_proj_with_mqa: BlockwiseFP8Linear,
kv_a_layernorm: RmsNorm,
kv_b_proj: Linear,
o_proj: Linear,
kv_b_proj: BlockwiseFP8Linear,
o_proj: BlockwiseFP8Linear,
rotary_emb: Arc<DeepSeekV2RotaryEmbedding>,
cfg: DeepSeekV2Config,
q_head_dim: usize,
@@ -552,44 +559,59 @@ impl Attention {
let q_head_dim = cfg.q_head_dim();
let q = match cfg.q_lora_rank {
Some(lora_rank) => {
let a = candle_nn::linear_b(
let a = quant::blockwise_fp8_linear_b(
cfg.hidden_size,
lora_rank,
&cfg.quantization_config,
cfg.attention_bias,
None,
vb.pp("q_a_proj"),
)?;
let norm = rms_norm(lora_rank, cfg.rms_norm_eps, vb.pp("q_a_layernorm"))?;
let b = candle_nn::linear_no_bias(
let b = quant::blockwise_fp8_linear_b(
lora_rank,
cfg.num_attention_heads * q_head_dim,
&cfg.quantization_config,
false,
None,
vb.pp("q_b_proj"),
)?;
QProj::Lora { a, norm, b }
}
None => QProj::Plain(candle_nn::linear_no_bias(
None => QProj::Plain(quant::blockwise_fp8_linear_b(
cfg.hidden_size,
cfg.num_attention_heads * q_head_dim,
&cfg.quantization_config,
false,
None,
vb.pp("q_proj"),
)?),
};

let kv_a_proj_with_mqa = candle_nn::linear_b(
let kv_a_proj_with_mqa = quant::blockwise_fp8_linear_b(
cfg.hidden_size,
cfg.kv_lora_rank + cfg.qk_rope_head_dim,
&cfg.quantization_config,
cfg.attention_bias,
None,
vb.pp("kv_a_proj_with_mqa"),
)?;
let kv_a_layernorm = rms_norm(cfg.kv_lora_rank, cfg.rms_norm_eps, vb.pp("kv_a_layernorm"))?;
let kv_b_proj = candle_nn::linear_no_bias(
let kv_b_proj = quant::blockwise_fp8_linear_b(
cfg.kv_lora_rank,
cfg.num_attention_heads * (q_head_dim - cfg.qk_rope_head_dim + cfg.v_head_dim),
&cfg.quantization_config,
false,
None,
vb.pp("kv_b_proj"),
)?;

let o_proj = candle_nn::linear_b(
let o_proj = quant::blockwise_fp8_linear_b(
cfg.num_attention_heads * cfg.v_head_dim,
cfg.hidden_size,
&cfg.quantization_config,
cfg.attention_bias,
None,
vb.pp("o_proj"),
)?;

@@ -725,9 +747,9 @@ impl Attention {
}

struct Mlp {
gate: Linear,
up: Linear,
down: Linear,
gate: BlockwiseFP8Linear,
up: BlockwiseFP8Linear,
down: BlockwiseFP8Linear,
act: Activation,
}

@@ -742,9 +764,30 @@ impl Mlp {
let intermediate_size = intermediate_size.unwrap_or(cfg.intermediate_size);

Ok(Self {
gate: candle_nn::linear_no_bias(hidden_size, intermediate_size, vb.pp("gate_proj"))?,
up: candle_nn::linear_no_bias(hidden_size, intermediate_size, vb.pp("up_proj"))?,
down: candle_nn::linear_no_bias(intermediate_size, hidden_size, vb.pp("down_proj"))?,
gate: quant::blockwise_fp8_linear_b(
hidden_size,
intermediate_size,
&cfg.quantization_config,
false,
None,
vb.pp("gate_proj"),
)?,
up: quant::blockwise_fp8_linear_b(
hidden_size,
intermediate_size,
&cfg.quantization_config,
false,
None,
vb.pp("up_proj"),
)?,
down: quant::blockwise_fp8_linear_b(
intermediate_size,
hidden_size,
&cfg.quantization_config,
false,
None,
vb.pp("down_proj"),
)?,
act: cfg.hidden_act,
})
}
@@ -1045,7 +1088,7 @@ impl DecoderLayer {
}

pub struct DeepSeekV2 {
lm_head: Linear,
lm_head: candle_nn::Linear,
embed_tokens: Embedding,
norm: RmsNorm,
layers: Vec<DecoderLayer>,
(Diff for the new deepseekv3 quant module was not loaded on this page.)
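
Since the new `quant` module's diff is not rendered above, here — purely as orientation — is a minimal sketch of the surface its call sites in `model.rs` imply: `QuantizedConfig`, `BlockwiseFP8Linear`, and `blockwise_fp8_linear_b`, under an assumed dequantize-at-load strategy. The `weight_block_size` and `weight_scale_inv` names, the role of the fifth `Option<DType>` argument, and the plain `Linear` wrapper are assumptions, not the commit's actual implementation.

```rust
use candle::{DType, Result, Tensor};
use candle_nn::{Linear, Module, VarBuilder};
use serde::Deserialize;

#[derive(Debug, Clone, Deserialize)]
pub struct QuantizedConfig {
    // Assumed config key: a [block_rows, block_cols] pair such as [128, 128].
    pub weight_block_size: Option<Vec<usize>>,
}

#[derive(Debug)]
pub struct BlockwiseFP8Linear {
    inner: Linear,
}

impl Module for BlockwiseFP8Linear {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        self.inner.forward(xs)
    }
}

/// Mirrors the call sites in model.rs: dims, optional quantization config, bias flag,
/// an optional dtype override (always `None` above; its role is not visible in this
/// diff), and a `VarBuilder`.
pub fn blockwise_fp8_linear_b(
    in_dim: usize,
    out_dim: usize,
    config: &Option<QuantizedConfig>,
    bias: bool,
    _dtype: Option<DType>,
    vb: VarBuilder,
) -> Result<BlockwiseFP8Linear> {
    // No quantization config: behave like the previous candle_nn::linear_b call.
    let Some(cfg) = config else {
        return Ok(BlockwiseFP8Linear {
            inner: candle_nn::linear_b(in_dim, out_dim, bias, vb)?,
        });
    };
    let (brows, bcols) = match cfg.weight_block_size.as_deref() {
        Some([r, c]) => (*r, *c),
        _ => candle::bail!("expected a 2-element weight_block_size"),
    };
    // Sketch only: assume the FP8 weight and its per-block scales can be surfaced via
    // the VarBuilder and upcast to f32; the real quant module presumably handles the
    // F8E4M3 dtype directly. "weight_scale_inv" is the assumed scale tensor name.
    let weight = vb.get((out_dim, in_dim), "weight")?.to_dtype(DType::F32)?;
    let scales = vb
        .get(
            (out_dim.div_ceil(brows), in_dim.div_ceil(bcols)),
            "weight_scale_inv",
        )?
        .to_dtype(DType::F32)?;
    // Dequantize once at load time: each element times the scale of its tile
    // (the same per-tile scheme as the standalone example near the top of the page).
    let (w, s) = (weight.to_vec2::<f32>()?, scales.to_vec2::<f32>()?);
    let mut dequant = Vec::with_capacity(out_dim * in_dim);
    for (r, row) in w.iter().enumerate() {
        for (c, v) in row.iter().enumerate() {
            dequant.push(v * s[r / brows][c / bcols]);
        }
    }
    let weight =
        Tensor::from_vec(dequant, (out_dim, in_dim), vb.device())?.to_dtype(vb.dtype())?;
    let bias = if bias { Some(vb.get(out_dim, "bias")?) } else { None };
    Ok(BlockwiseFP8Linear {
        inner: Linear::new(weight, bias),
    })
}
```

When no `quantization_config` is present, the fallback path matches how `candle_nn::linear_b` and `linear_no_bias` were called before this commit, so unquantized checkpoints keep loading unchanged.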
