Make the metal sdpa tests deterministic. (#2750)
LaurentMazare authored Jan 28, 2025
1 parent da02b59 commit ab90194
Showing 2 changed files with 51 additions and 75 deletions.
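
The diff replaces the device-side Tensor::randn calls in these tests with a host-side helper that samples a normal distribution from a seeded StdRng, so every run sees identical q/k/v inputs and the error thresholds can be checked deterministically. Below is a minimal standalone sketch of that pattern; the main driver and the CPU device are illustrative stand-ins (the tests themselves run on Device::new_metal(0)):

use candle::{Device, Result, Shape, Tensor};
use rand::SeedableRng;
use rand_distr::Distribution;

// Sample a standard-normal tensor from a seeded host RNG: the same seed
// always yields the same values, unlike backend-dependent Tensor::randn.
fn randn<S: Into<Shape>>(rng: &mut rand::rngs::StdRng, shape: S, dev: &Device) -> Result<Tensor> {
    let shape = shape.into();
    let normal = rand_distr::Normal::new(0.0f32, 1.0).unwrap();
    let vs: Vec<f32> = (0..shape.elem_count()).map(|_| normal.sample(rng)).collect();
    Tensor::from_vec(vs, &shape, dev)
}

fn main() -> Result<()> {
    let dev = Device::Cpu; // stand-in; the tests use Device::new_metal(0)?
    let mut rng = rand::rngs::StdRng::seed_from_u64(42);
    let a = randn(&mut rng, (2, 3), &dev)?;
    let mut rng = rand::rngs::StdRng::seed_from_u64(42);
    let b = randn(&mut rng, (2, 3), &dev)?;
    assert_eq!(a.to_vec2::<f32>()?, b.to_vec2::<f32>()?); // identical across runs
    Ok(())
}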
3 changes: 2 additions & 1 deletion candle-nn/Cargo.toml
@@ -26,6 +26,7 @@ candle-metal-kernels = { workspace = true, optional = true }
 anyhow = { workspace = true }
 clap = { workspace = true }
 rand = { workspace = true }
+rand_distr = { workspace = true }
 criterion = { workspace = true }
 
 [features]
@@ -37,4 +38,4 @@ metal = ["candle/metal", "dep:candle-metal-kernels", "dep:metal"]
 
 [[bench]]
 name = "bench_main"
-harness = false
\ No newline at end of file
+harness = false
123 changes: 49 additions & 74 deletions candle-nn/tests/sdpa.rs
@@ -1,101 +1,98 @@
 #[cfg(feature = "metal")]
 mod metal_sdpa_tests {
-    #[test]
-    fn sdpa_full() -> candle::Result<()> {
-        use candle::{DType, Device, Tensor};
+    use candle::{DType, Device, Result, Shape, Tensor};
+    use rand::SeedableRng;
+    use rand_distr::Distribution;
+    use std::ops::{Div, Mul};
+
+    fn randn<S: Into<Shape>>(
+        rng: &mut rand::rngs::StdRng,
+        shape: S,
+        dev: &Device,
+    ) -> Result<Tensor> {
+        let shape = shape.into();
+        let elem_count = shape.elem_count();
+        let normal = rand_distr::Normal::new(0.0, 1.0).unwrap();
+        let vs: Vec<f32> = (0..elem_count).map(|_| normal.sample(rng)).collect();
+        Tensor::from_vec(vs, &shape, dev)
+    }
+
+    #[test]
+    fn sdpa_full() -> Result<()> {
         // Force seqlen = 100
         const BS: usize = 4;
         const R: usize = 4;
         const L: usize = 4;
         const DK: usize = 64;
         const H: usize = 3;
-        let scale: f64 = f64::from(DK as u32).sqrt().recip();
-
+        let scale: f64 = f64::from(DK as u32).sqrt().recip();
         let device = Device::new_metal(0)?;
-
-        let q = Tensor::randn(0f32, 1f32, (BS, H, R, DK), &device)?;
-        let k = Tensor::randn(0f32, 1f32, (BS, H, L, DK), &device)?;
-        let v = Tensor::randn(0f32, 1f32, (BS, H, L, DK), &device)?;
-
+        let mut rng = rand::rngs::StdRng::seed_from_u64(42);
+        let q = randn(&mut rng, (BS, H, R, DK), &device)?;
+        let k = randn(&mut rng, (BS, H, L, DK), &device)?;
+        let v = randn(&mut rng, (BS, H, L, DK), &device)?;
         let ground_truth = {
             let att = (q.clone() * scale)?.matmul(&k.clone().t()?)?;
             let att = candle_nn::ops::softmax_last_dim(&att.to_dtype(DType::F32)?)?
                 .to_dtype(q.dtype())?;
             att.matmul(&v.clone())?
         };
-
         let sdpa_output = candle_nn::ops::sdpa(&q, &k, &v, scale as f32, 1.)?;
-
         assert_eq!(ground_truth.shape(), sdpa_output.shape());
-
         let error: f32 = ((&ground_truth - &sdpa_output)?.abs()? / &ground_truth.abs()?)?
             .sum_all()?
             .to_scalar()?;
-
-        assert!(error <= 0.0005, "{}", error);
-
+        assert!(error <= 0.0004, "{}", error);
         Ok(())
     }
 
     #[test]
-    fn sdpa_vector() -> candle::Result<()> {
-        use candle::{DType, Device, Tensor};
-
+    fn sdpa_vector() -> Result<()> {
         // Allow vectorized, seqlen = 1
         const BS: usize = 4;
         const R: usize = 1;
         const L: usize = 1;
         const DK: usize = 64;
         const H: usize = 3;
-        let scale: f64 = f64::from(DK as u32).sqrt().recip();
-
+        let scale: f64 = f64::from(DK as u32).sqrt().recip();
         let device = Device::new_metal(0)?;
-
-        let q = Tensor::randn(0f32, 1f32, (BS, H, R, DK), &device)?;
-        let k = Tensor::randn(0f32, 1f32, (BS, H, L, DK), &device)?;
-        let v = Tensor::randn(0f32, 1f32, (BS, H, L, DK), &device)?;
-
+        let mut rng = rand::rngs::StdRng::seed_from_u64(4242);
+        let q = randn(&mut rng, (BS, H, R, DK), &device)?;
+        let k = randn(&mut rng, (BS, H, L, DK), &device)?;
+        let v = randn(&mut rng, (BS, H, L, DK), &device)?;
         let ground_truth = {
             let att = (q.clone() * scale)?.matmul(&k.clone().t()?)?;
             let att = candle_nn::ops::softmax_last_dim(&att.to_dtype(DType::F32)?)?
                 .to_dtype(q.dtype())?;
             att.matmul(&v.clone())?
         };
-
         let sdpa_output = candle_nn::ops::sdpa(&q, &k, &v, scale as f32, 1.)?;
-
         assert_eq!(ground_truth.shape(), sdpa_output.shape());
-
         let error: f32 = ((&ground_truth - &sdpa_output)?.abs()? / &ground_truth.abs()?)?
             .sum_all()?
             .to_scalar()?;
-
-        assert!(error <= 0.0001, "{}", error);
-
+        assert!(error <= 0.000, "{}", error);
         Ok(())
     }
 
     #[test]
-    fn sdpa_full_softcapping() -> candle::Result<()> {
-        use candle::{DType, Device, Tensor};
-        use std::ops::{Div, Mul};
-
+    fn sdpa_full_softcapping() -> Result<()> {
         // Allow vectorized, seqlen = 1
         const BS: usize = 4;
         const R: usize = 4;
         const L: usize = 4;
         const DK: usize = 64;
         const H: usize = 3;
         const SOFTCAP: f64 = 50.;
-        let scale: f64 = f64::from(DK as u32).sqrt().recip();
-
+        let scale: f64 = f64::from(DK as u32).sqrt().recip();
         let device = Device::new_metal(0)?;
-
-        let q = Tensor::randn(0f32, 1f32, (BS, H, R, DK), &device)?;
-        let k = Tensor::randn(0f32, 1f32, (BS, H, L, DK), &device)?;
-        let v = Tensor::randn(0f32, 1f32, (BS, H, L, DK), &device)?;
-
+        let mut rng = rand::rngs::StdRng::seed_from_u64(424242);
+        let q = randn(&mut rng, (BS, H, R, DK), &device)?;
+        let k = randn(&mut rng, (BS, H, L, DK), &device)?;
+        let v = randn(&mut rng, (BS, H, L, DK), &device)?;
         let ground_truth = {
             let att = (q.clone() * scale)?.matmul(&k.clone().t()?)?;
             let att = candle_nn::ops::softmax_last_dim(
@@ -107,40 +104,31 @@ mod metal_sdpa_tests {
                 .to_dtype(q.dtype())?;
             att.matmul(&v.clone())?
         };
-
         let sdpa_output = candle_nn::ops::sdpa(&q, &k, &v, scale as f32, SOFTCAP as f32)?;
-
         assert_eq!(ground_truth.shape(), sdpa_output.shape());
-
         let error: f32 = ((&ground_truth - &sdpa_output)?.abs()? / &ground_truth.abs()?)?
             .sum_all()?
             .to_scalar()?;
-
         assert!(error <= 0.0005, "{}", error);
-
         Ok(())
     }
 
     #[test]
-    fn sdpa_vector_softcapping() -> candle::Result<()> {
-        use candle::{DType, Device, Tensor};
-        use std::ops::{Div, Mul};
-
+    fn sdpa_vector_softcapping() -> Result<()> {
         // Allow vectorized, seqlen = 1
         const BS: usize = 4;
         const R: usize = 1;
         const L: usize = 1;
         const DK: usize = 64;
         const H: usize = 3;
         const SOFTCAP: f64 = 50.;
-        let scale: f64 = f64::from(DK as u32).sqrt().recip();
-
+        let scale: f64 = f64::from(DK as u32).sqrt().recip();
         let device = Device::new_metal(0)?;
-
-        let q = Tensor::randn(0f32, 1f32, (BS, H, R, DK), &device)?;
-        let k = Tensor::randn(0f32, 1f32, (BS, H, L, DK), &device)?;
-        let v = Tensor::randn(0f32, 1f32, (BS, H, L, DK), &device)?;
-
+        let mut rng = rand::rngs::StdRng::seed_from_u64(42424242);
+        let q = randn(&mut rng, (BS, H, R, DK), &device)?;
+        let k = randn(&mut rng, (BS, H, L, DK), &device)?;
+        let v = randn(&mut rng, (BS, H, L, DK), &device)?;
         let ground_truth = {
             let att = (q.clone() * scale)?.matmul(&k.clone().t()?)?;
             let att = candle_nn::ops::softmax_last_dim(
@@ -152,55 +140,42 @@ mod metal_sdpa_tests {
                 .to_dtype(q.dtype())?;
             att.matmul(&v.clone())?
         };
-
         let sdpa_output = candle_nn::ops::sdpa(&q, &k, &v, scale as f32, SOFTCAP as f32)?;
-
         assert_eq!(ground_truth.shape(), sdpa_output.shape());
-
         let error: f32 = ((&ground_truth - &sdpa_output)?.abs()? / &ground_truth.abs()?)?
             .sum_all()?
             .to_scalar()?;
-
         assert!(error <= 0.0001, "{}", error);
-
         Ok(())
     }
 
     #[test]
-    fn sdpa_vector_cross() -> candle::Result<()> {
-        use candle::{DType, Device, Tensor};
-
+    fn sdpa_vector_cross() -> Result<()> {
         // Allow vectorized, seqlen = 1. Simulate cross attention case where R != L, R = 1
         const BS: usize = 4;
         const R: usize = 1;
         const L: usize = 24;
         const DK: usize = 64;
         const H: usize = 3;
-        let scale: f64 = f64::from(DK as u32).sqrt().recip();
-
+        let scale: f64 = f64::from(DK as u32).sqrt().recip();
         let device = Device::new_metal(0)?;
-
-        let q = Tensor::randn(0f32, 1f32, (BS, H, R, DK), &device)?;
-        let k = Tensor::randn(0f32, 1f32, (BS, H, L, DK), &device)?;
-        let v = Tensor::randn(0f32, 1f32, (BS, H, L, DK), &device)?;
-
+        let mut rng = rand::rngs::StdRng::seed_from_u64(4242424242);
+        let q = randn(&mut rng, (BS, H, R, DK), &device)?;
+        let k = randn(&mut rng, (BS, H, L, DK), &device)?;
+        let v = randn(&mut rng, (BS, H, L, DK), &device)?;
         let ground_truth = {
             let att = (q.clone() * scale)?.matmul(&k.clone().t()?)?;
             let att = candle_nn::ops::softmax_last_dim(&att.to_dtype(DType::F32)?)?
                 .to_dtype(q.dtype())?;
             att.matmul(&v.clone())?
         };
-
         let sdpa_output = candle_nn::ops::sdpa(&q, &k, &v, scale as f32, 1.)?;
-
         assert_eq!(ground_truth.shape(), sdpa_output.shape());
-
         let error: f32 = ((&ground_truth - &sdpa_output)?.abs()? / &ground_truth.abs()?)?
             .sum_all()?
             .to_scalar()?;
-
         assert!(error <= 0.0013, "{}", error);
-
         Ok(())
     }
 }
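
Note on the assertions: the quantity being bounded is a summed element-wise relative error, the sum over all elements of |ground_truth - sdpa_output| / |ground_truth|, not a mean, so each threshold is tuned to its test's shape. With seeded inputs the bounds become stable enough to tighten (from 0.0005 to 0.0004 for sdpa_full), and the seqlen-1 sdpa_vector case now asserts an exact match (error <= 0.000). A sketch of the check as a helper (the function name is hypothetical; the tests inline this expression):

fn relative_error(gt: &candle::Tensor, out: &candle::Tensor) -> candle::Result<f32> {
    // Summed |gt - out| / |gt|; this grows with element count, which is why
    // each test carries its own shape-specific threshold.
    ((gt - out)?.abs()? / gt.abs()?)?.sum_all()?.to_scalar::<f32>()
}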
