Skip to content

Commit

Permalink
fmt
Browse files Browse the repository at this point in the history
  • Loading branch information
ivarflakstad committed Jan 12, 2024
1 parent e63bb86 commit e06e8d0
Showing 1 changed file with 29 additions and 9 deletions.
38 changes: 29 additions & 9 deletions candle-metal-kernels/src/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -975,7 +975,6 @@ fn run_random<T: Clone>(name: &'static str, seed: u64, length: usize, a: f32, b:
.unwrap();
}


command_buffer.commit();
command_buffer.wait_until_completed();

Expand All @@ -984,7 +983,6 @@ fn run_random<T: Clone>(name: &'static str, seed: u64, length: usize, a: f32, b:

#[test]
fn random() {

fn calc_mean(data: &[f32]) -> f32 {
let sum = data.iter().sum::<f32>() as f32;
let count = data.len();
Expand All @@ -997,10 +995,14 @@ fn random() {
let count = data.len();
assert!(count > 0);

let variance = data.iter().map(|value| {
let diff = mean - (*value as f32);
diff * diff
}).sum::<f32>() / count as f32;
let variance = data
.iter()
.map(|value| {
let diff = mean - (*value as f32);
diff * diff
})
.sum::<f32>()
/ count as f32;

variance.sqrt()
}
Expand All @@ -1017,11 +1019,29 @@ fn random() {

macro_rules! validate_random {
($type:ty) => {
let results: Vec<f32> = run_random::<$type>(concat!("rand_uniform_", stringify!($type)), seed, length, min, max).into_iter().map(f32::from).collect();
let results: Vec<f32> = run_random::<$type>(
concat!("rand_uniform_", stringify!($type)),
seed,
length,
min,
max,
)
.into_iter()
.map(f32::from)
.collect();
results.iter().for_each(|v| assert!(*v >= min && *v <= max));
assert!(calc_mean(&results) > -1.0 && calc_mean(&results) < 1.0);

let results: Vec<f32> = run_random::<$type>(concat!("rand_normal_", stringify!($type)), seed, length, mean, stddev).into_iter().map(f32::from).collect();
let results: Vec<f32> = run_random::<$type>(
concat!("rand_normal_", stringify!($type)),
seed,
length,
mean,
stddev,
)
.into_iter()
.map(f32::from)
.collect();
assert!((calc_mean(&results) - mean).abs() < mean / 10.0);
assert!((calc_stddev(&results) - stddev).abs() < stddev / 10.0);
};
Expand All @@ -1030,4 +1050,4 @@ fn random() {
validate_random!(f32);
validate_random!(f16);
validate_random!(bf16);
}
}

0 comments on commit e06e8d0

Please sign in to comment.