Skip to content

Commit

Permalink
DType parametrization.
Browse files Browse the repository at this point in the history
  • Loading branch information
LaurentMazare committed Jan 28, 2025
1 parent 9966f5c commit 20ccdf7
Show file tree
Hide file tree
Showing 3 changed files with 61 additions and 14 deletions.
23 changes: 23 additions & 0 deletions candle-metal-kernels/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,29 @@ const TERNARY: &str = include_str!("ternary.metal");
const UNARY: &str = include_str!("unary.metal");
const SDPA: &str = include_str!("scaled_dot_product_attention.metal");

#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DType {
BF16,
F16,
F32,
I64,
U32,
U8,
}

impl DType {
fn size_in_bytes(&self) -> usize {
match self {
Self::U8 => 1,
Self::U32 => 4,
Self::I64 => 8,
Self::BF16 => 2,
Self::F16 => 2,
Self::F32 => 4,
}
}
}

#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Source {
Affine,
Expand Down
50 changes: 37 additions & 13 deletions candle-metal-kernels/src/sort.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use crate::utils::{BufferOffset, EncoderProvider};
use crate::{set_params, Kernels, MetalKernelError, Source};
use crate::{set_params, DType, Kernels, MetalKernelError, Source};
use metal::{Buffer, ComputeCommandEncoderRef, Device, MTLResourceOptions, MTLSize};

#[allow(clippy::too_many_arguments)]
Expand Down Expand Up @@ -39,11 +39,23 @@ pub fn call_arg_sort(
Ok(())
}

fn mlx_dtype_str(dtype: DType) -> &'static str {
match dtype {
DType::U8 => "uint8",
DType::U32 => "uint32",
DType::I64 => "int64",
DType::F16 => "float16",
DType::BF16 => "bfloat16",
DType::F32 => "float32",
}
}

#[allow(clippy::too_many_arguments)]
pub fn multi_block_sort(
device: &Device,
ep: impl EncoderProvider,
kernels: &Kernels,
dtype: DType,
bn: usize,
tn: usize,
nblocks: usize,
Expand All @@ -52,9 +64,10 @@ pub fn multi_block_sort(
src: BufferOffset,
dst: &Buffer,
) -> Result<(), MetalKernelError> {
let dtype_str = mlx_dtype_str(dtype);
// Do allocations
let el_count = nrows * ncols;
let bytes_len = 0; // TODO
let bytes_len = (el_count * dtype.size_in_bytes()) as u64;
let mut dev_vals_0 = device.new_buffer(bytes_len, MTLResourceOptions::StorageModePrivate);
let mut dev_vals_1 = device.new_buffer(bytes_len, MTLResourceOptions::StorageModePrivate);
let mut dev_idxs_0 =
Expand All @@ -70,7 +83,7 @@ pub fn multi_block_sort(
let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
// Do blockwise sort
{
let name = format!("sort_mbsort_float32_uint32_bn{bn}_tn{tn}");
let name = format!("sort_mbsort_{dtype_str}_uint32_bn{bn}_tn{tn}");
let pipeline = kernels.load_pipeline(device, Source::MlxSort, name)?;
encoder.set_compute_pipeline_state(&pipeline);
set_params!(
Expand Down Expand Up @@ -102,6 +115,8 @@ pub fn multi_block_sort(
let mut ping = false;
let mut merge_tiles = 2;
let n_thr_per_group = usize::min(nblocks + 1, 1024);
let partition_name = format!("partition_mbsort_{dtype_str}_uint32_bn{bn}_tn{tn}");
let merge_name = format!("merge_mbsort_float32_uint32_bn{bn}_tn{tn}");
while merge_tiles / 2 < nblocks {
let (dev_vals_in, dev_vals_out) = if ping {
(&mut dev_vals_1, &mut dev_vals_0)
Expand All @@ -116,8 +131,8 @@ pub fn multi_block_sort(
ping = !ping;
// Do partition
{
let name = format!("partition_mbsort_float32_uint32_bn{bn}_tn{tn}");
let pipeline = kernels.load_pipeline(device, Source::MlxSort, name)?;
let pipeline =
kernels.load_pipeline(device, Source::MlxSort, partition_name.clone())?;
encoder.set_compute_pipeline_state(&pipeline);
set_params!(
encoder,
Expand All @@ -144,8 +159,7 @@ pub fn multi_block_sort(
}
// Do merge
{
let name = format!("merge_mbsort_float32_uint32_bn{bn}_tn{tn}");
let pipeline = kernels.load_pipeline(device, Source::MlxSort, name)?;
let pipeline = kernels.load_pipeline(device, Source::MlxSort, merge_name.clone())?;
encoder.set_compute_pipeline_state(&pipeline);
set_params!(
encoder,
Expand Down Expand Up @@ -180,11 +194,19 @@ pub fn multi_block_sort(
&mut dev_idxs_0
};
// Copy output with appropriate strides
let copy_kernel = match dtype {
DType::U8 => crate::copy2d::U8,
DType::U32 => crate::copy2d::U32,
DType::I64 => crate::copy2d::I64,
DType::BF16 => crate::copy2d::BFLOAT,
DType::F16 => crate::copy2d::HALF,
DType::F32 => crate::copy2d::FLOAT,
};
crate::call_copy2d(
device,
encoder,
kernels,
crate::copy2d::Kernel("todo"),
copy_kernel,
dev_idxs_out,
dst,
/* d1 */ nrows,
Expand All @@ -202,14 +224,16 @@ pub fn block_sort(
device: &Device,
ep: impl EncoderProvider,
kernels: &Kernels,
dtype: DType,
bn: usize,
tn: usize,
nrows: usize,
ncols: usize,
src: BufferOffset,
dst: &Buffer,
) -> Result<(), MetalKernelError> {
let name = format!("carg_block_sort_float32_uint32_bn{bn}_tn{tn}");
let dtype_str = mlx_dtype_str(dtype);
let name = format!("carg_block_sort_{dtype_str}_uint32_bn{bn}_tn{tn}");
let pipeline = kernels.load_pipeline(device, Source::MlxSort, name)?;
let encoder = ep.encoder();
let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
Expand Down Expand Up @@ -247,26 +271,26 @@ pub fn call_mlx_arg_sort(
device: &Device,
ep: impl EncoderProvider,
kernels: &Kernels,
dtype: DType,
nrows: usize,
ncols: usize,
size_of_dtype: usize,
src: BufferOffset,
dst: &Buffer,
) -> Result<(), MetalKernelError> {
let tn = 8;
let bn = match ncols.div_ceil(tn) {
257.. if size_of_dtype <= 4 => 512,
257.. if dtype.size_in_bytes() <= 4 => 512,
129.. => 256,
0..129 => 128,
};
let n_per_block = bn * tn;
let n_blocks = ncols.div_ceil(n_per_block);
if n_blocks > 1 {
multi_block_sort(
device, ep, kernels, bn, tn, n_blocks, nrows, ncols, src, dst,
device, ep, kernels, dtype, bn, tn, n_blocks, nrows, ncols, src, dst,
)?
} else {
block_sort(device, ep, kernels, bn, tn, nrows, ncols, src, dst)?
block_sort(device, ep, kernels, dtype, bn, tn, nrows, ncols, src, dst)?
}
Ok(())
}
2 changes: 1 addition & 1 deletion candle-metal-kernels/src/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -620,9 +620,9 @@ fn run_mlx_sort<T: Clone>(v: &[T], ncols: usize) -> Vec<u32> {
&device,
command_buffer,
&kernels,
DType::F32,
nrows,
ncols,
4,
BufferOffset::zero_offset(&input),
&output,
)
Expand Down

0 comments on commit 20ccdf7

Please sign in to comment.