Skip to content

Commit

Permalink
Add a larger test.
Browse files Browse the repository at this point in the history
  • Loading branch information
LaurentMazare committed Jan 28, 2025
1 parent 20ccdf7 commit 8c2f7af
Showing 1 changed file with 18 additions and 0 deletions.
18 changes: 18 additions & 0 deletions candle-metal-kernels/src/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -634,6 +634,9 @@ fn run_mlx_sort<T: Clone>(v: &[T], ncols: usize) -> Vec<u32> {

#[test]
fn mlx_sort() {
use rand::SeedableRng;
use rand_distr::Distribution;

let input: Vec<_> = (0..8).map(|v| v as f32).collect();
let result = run_mlx_sort(&input, 4);
assert_eq!(result, [0, 1, 2, 3, 0, 1, 2, 3]);
Expand All @@ -648,6 +651,21 @@ fn mlx_sort() {
assert_eq!(&result[400..600], out);
assert_eq!(&result[600..800], out);
assert_eq!(&result[800..], out);

// Multi-block test
let ncols = 16000;
let mut rng = rand::rngs::StdRng::seed_from_u64(299792458);
let normal = rand_distr::Normal::new(0.0, 1.0).unwrap();
let input: Vec<f32> = (0..ncols * 16).map(|_| normal.sample(&mut rng)).collect();
let result = run_mlx_sort(&input, ncols);
for start in 0..16 {
let slice = &input[start * ncols..(start + 1) * ncols];
let result = &result[start * ncols..(start + 1) * ncols];
let mut perm: Vec<usize> = (0..ncols).collect();
perm.sort_by(|i1, i2| slice[*i1].total_cmp(&slice[*i2]));
let perm: Vec<_> = perm.into_iter().map(|v| v as u32).collect();
assert_eq!(perm, result);
}
}

#[test]
Expand Down

0 comments on commit 8c2f7af

Please sign in to comment.