diff --git a/candle-core/benches/bench_main.rs b/candle-core/benches/bench_main.rs index 2e1816fd71..9cb1cf8b59 100644 --- a/candle-core/benches/bench_main.rs +++ b/candle-core/benches/bench_main.rs @@ -1,10 +1,12 @@ mod benchmarks; use criterion::criterion_main; + criterion_main!( benchmarks::affine::benches, benchmarks::matmul::benches, benchmarks::random::benches, + benchmarks::reduce::benches, benchmarks::where_cond::benches, benchmarks::conv_transpose2d::benches, benchmarks::qmatmul::benches, diff --git a/candle-core/benches/benchmarks/mod.rs b/candle-core/benches/benchmarks/mod.rs index 579c5f3f0b..721b292d6f 100644 --- a/candle-core/benches/benchmarks/mod.rs +++ b/candle-core/benches/benchmarks/mod.rs @@ -3,6 +3,7 @@ pub(crate) mod conv_transpose2d; pub(crate) mod matmul; pub(crate) mod qmatmul; pub(crate) mod random; +pub(crate) mod reduce; pub(crate) mod unary; pub(crate) mod where_cond; diff --git a/candle-core/benches/benchmarks/reduce.rs b/candle-core/benches/benchmarks/reduce.rs new file mode 100644 index 0000000000..e0755a7080 --- /dev/null +++ b/candle-core/benches/benchmarks/reduce.rs @@ -0,0 +1,158 @@ +use crate::benchmarks::{BenchDevice, BenchDeviceHandler}; +use candle_core::{DType, Device, Tensor}; +use criterion::{black_box, criterion_group, Criterion, Throughput}; +use half::{bf16, f16}; +use std::time::Instant; + +fn run_sum(a: &Tensor) { + a.sum_keepdim(2).unwrap(); +} +fn run_arg_min(a: &Tensor) { + a.argmin_keepdim(2).unwrap(); +} + +fn criterion_benchmark(c: &mut Criterion) { + let handler = BenchDeviceHandler::new().unwrap(); + let (lo, up) = (-1000.0f32, 1000.0f32); + for device in handler.devices { + run_reduce(c, &device, (lo, up), false); + run_reduce(c, &device, (f16::from_f32(lo), f16::from_f32(up)), false); + run_reduce(c, &device, (bf16::from_f32(lo), bf16::from_f32(up)), false); + + run_arg_reduce(c, &device, (lo, up), false); + run_arg_reduce(c, &device, (f16::from_f32(lo), f16::from_f32(up)), false); + run_arg_reduce(c, &device, (bf16::from_f32(lo), bf16::from_f32(up)), false); + + run_reduce(c, &device, (lo, up), true); + run_reduce(c, &device, (f16::from_f32(lo), f16::from_f32(up)), true); + run_reduce(c, &device, (bf16::from_f32(lo), bf16::from_f32(up)), true); + + run_arg_reduce(c, &device, (lo, up), true); + run_arg_reduce(c, &device, (f16::from_f32(lo), f16::from_f32(up)), true); + run_arg_reduce(c, &device, (bf16::from_f32(lo), bf16::from_f32(up)), true); + } +} + +fn run_reduce( + c: &mut Criterion, + device: &Device, + (lo, up): (T, T), + strided: bool, +) { + let b = 1; + let m = 1024; + let k = 1024; + + let a = if strided { + Tensor::rand(lo, up, (b, m, k), &device) + .unwrap() + .transpose(0, 2) + .unwrap() + } else { + Tensor::rand(lo, up, (b, m, k), &device).unwrap() + }; + + let flops = b * m * k * T::DTYPE.size_in_bytes(); + + let name = match T::DTYPE { + DType::F32 => { + if strided { + "reduce_f32_strided" + } else { + "reduce_f32" + } + } + DType::F16 => { + if strided { + "reduce_f16_strided" + } else { + "reduce_f16" + } + } + DType::BF16 => { + if strided { + "reduce_bf16_strided" + } else { + "reduce_bf16" + } + } + _ => "unknown", + }; + + let mut group = c.benchmark_group(device.bench_name(name)); + group.throughput(Throughput::Bytes(flops as u64)); + group.bench_function("iter", move |b| { + b.iter_custom(|iters| { + let start = Instant::now(); + for _i in 0..iters { + run_sum(black_box(&a)); + } + device.sync().unwrap(); + start.elapsed() + }) + }); + group.finish(); +} + +fn run_arg_reduce( + c: &mut Criterion, + device: 
&Device, + (lo, up): (T, T), + strided: bool, +) { + let b = 1; + let m = 1024; + let k = 1024; + + let a = if strided { + Tensor::rand(lo, up, (b, m, k), &device) + .unwrap() + .transpose(0, 2) + .unwrap() + } else { + Tensor::rand(lo, up, (b, m, k), &device).unwrap() + }; + + let flops = b * m * k * T::DTYPE.size_in_bytes(); + + let name = match T::DTYPE { + DType::F32 => { + if strided { + "arg_reduce_f32_strided" + } else { + "arg_reduce_f32" + } + } + DType::F16 => { + if strided { + "arg_reduce_f16_strided" + } else { + "arg_reduce_f16" + } + } + DType::BF16 => { + if strided { + "arg_reduce_bf16_strided" + } else { + "arg_reduce_bf16" + } + } + _ => "unknown", + }; + + let mut group = c.benchmark_group(device.bench_name(name)); + group.throughput(Throughput::Bytes(flops as u64)); + group.bench_function("iter", move |b| { + b.iter_custom(|iters| { + let start = Instant::now(); + for _i in 0..iters { + run_arg_min(black_box(&a)); + } + device.sync().unwrap(); + start.elapsed() + }) + }); + group.finish(); +} + +criterion_group!(benches, criterion_benchmark); diff --git a/candle-core/src/metal_backend/device.rs b/candle-core/src/metal_backend/device.rs index 25523a40c6..43869a0c3a 100644 --- a/candle-core/src/metal_backend/device.rs +++ b/candle-core/src/metal_backend/device.rs @@ -2,7 +2,6 @@ use crate::{DType, Result}; use candle_metal_kernels::Kernels; use metal::{Buffer, CommandBuffer, CommandQueue, MTLResourceOptions, NSUInteger}; use std::collections::HashMap; -use std::ffi::c_void; use std::path::Path; use std::sync::{Arc, Mutex, RwLock}; @@ -236,7 +235,7 @@ impl MetalDevice { pub fn new_buffer_with_data(&self, data: &[T]) -> Result> { let size = core::mem::size_of_val(data) as NSUInteger; let new_buffer = self.device.new_buffer_with_data( - data.as_ptr() as *const c_void, + data.as_ptr().cast(), size, MTLResourceOptions::StorageModeManaged, ); diff --git a/candle-core/src/metal_backend/mod.rs b/candle-core/src/metal_backend/mod.rs index 70a512bc8e..433188cff7 100644 --- a/candle-core/src/metal_backend/mod.rs +++ b/candle-core/src/metal_backend/mod.rs @@ -265,6 +265,7 @@ impl BackendStorage for MetalStorage { fn reduce_op(&self, op: ReduceOp, layout: &Layout, sum_dims: &[usize]) -> Result { let device = self.device.clone(); + let src_stride = layout.stride(); let src_dims = layout.shape().dims(); // Source dims and strides with the sum dims at the end. @@ -278,13 +279,72 @@ impl BackendStorage for MetalStorage { stride.push(src_stride[dim_idx]); } } + for &dim_idx in sum_dims.iter() { dims.push(src_dims[dim_idx]); stride.push(src_stride[dim_idx]); } - // The reduction loop requires the shared array to be properly initialized and for - // this we want the number of threads to be a power of two. 
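+        // When both the source layout and the permuted reduction shape are contiguous we can
+        // dispatch the simpler contiguous kernels below and skip the strided index computation.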
+ let reduction_shape = Shape::from(dims.clone()); + + if layout.is_contiguous() && reduction_shape.is_contiguous(&stride) { + let (name, check_empty, return_index) = match (op, self.dtype) { + (ReduceOp::Sum, DType::F32) => ("fast_sum_f32", false, false), + (ReduceOp::Min, DType::F32) => ("fast_min_f32", true, false), + (ReduceOp::Max, DType::F32) => ("fast_max_f32", true, false), + (ReduceOp::ArgMin, DType::F32) => ("fast_argmin_f32", true, true), + (ReduceOp::ArgMax, DType::F32) => ("fast_argmax_f32", true, true), + (ReduceOp::Sum, DType::U32) => ("fast_sum_u32", false, false), + (ReduceOp::Min, DType::U32) => ("fast_min_u32", true, false), + (ReduceOp::Max, DType::U32) => ("fast_max_u32", true, false), + (ReduceOp::ArgMin, DType::U32) => ("fast_argmin_u32", true, true), + (ReduceOp::ArgMax, DType::U32) => ("fast_argmax_u32", true, true), + (ReduceOp::Sum, DType::F16) => ("fast_sum_f16", false, false), + (ReduceOp::Min, DType::F16) => ("fast_min_f16", true, false), + (ReduceOp::Max, DType::F16) => ("fast_max_f16", true, false), + (ReduceOp::ArgMin, DType::F16) => ("fast_argmin_f16", true, true), + (ReduceOp::ArgMax, DType::F16) => ("fast_argmax_f16", true, true), + (ReduceOp::Sum, DType::BF16) => ("fast_sum_bf16", false, false), + (ReduceOp::Min, DType::BF16) => ("fast_min_bf16", true, false), + (ReduceOp::Max, DType::BF16) => ("fast_max_bf16", true, false), + (ReduceOp::ArgMin, DType::BF16) => ("fast_argmin_bf16", true, true), + (ReduceOp::ArgMax, DType::BF16) => ("fast_argmax_bf16", true, true), + (ReduceOp::Sum, DType::I64) => ("fast_sum_i64", false, false), + (ReduceOp::Min, DType::I64) => ("fast_min_i64", true, false), + (ReduceOp::Max, DType::I64) => ("fast_max_i64", true, false), + (ReduceOp::ArgMin, DType::I64) => ("fast_argmin_i64", true, true), + (ReduceOp::ArgMax, DType::I64) => ("fast_argmax_i64", true, true), + (ReduceOp::Sum, DType::U8) => ("fast_sum_u8", false, false), + (ReduceOp::Min, DType::U8) => ("fast_min_u8", true, false), + (ReduceOp::Max, DType::U8) => ("fast_max_u8", true, false), + (ReduceOp::ArgMin, DType::U8) => ("fast_argmin_u8", true, true), + (ReduceOp::ArgMax, DType::U8) => ("fast_argmax_u8", true, true), + (k, dtype) => { + crate::bail!("Metal contiguous reduce op {k:?} {dtype:?} not implemented") + } + }; + if check_empty && layout.shape().elem_count() == 0 { + Err(crate::Error::EmptyTensor { op: "reduce" }.bt())? 
+ } + let dtype = if return_index { DType::U32 } else { self.dtype }; + let buffer = device.new_buffer(dst_el, dtype, "reduce")?; + let command_buffer = self.device.command_buffer()?; + let src = buffer_o(&self.buffer, layout, self.dtype); + candle_metal_kernels::call_reduce_contiguous( + &device.device, + &command_buffer, + &device.kernels, + name, + src_dims, + dst_el, + src, + &buffer, + ) + .map_err(MetalError::from)?; + + return Ok(Self::new(buffer, device, dst_el, dtype)); + } + let (name, check_empty, return_index) = match (op, self.dtype) { (ReduceOp::Sum, DType::F32) => ("fast_sum_f32_strided", false, false), (ReduceOp::Min, DType::F32) => ("fast_min_f32_strided", true, false), @@ -316,7 +376,7 @@ impl BackendStorage for MetalStorage { (ReduceOp::Max, DType::U8) => ("fast_max_u8_strided", true, false), (ReduceOp::ArgMin, DType::U8) => ("fast_argmin_u8_strided", true, true), (ReduceOp::ArgMax, DType::U8) => ("fast_argmax_u8_strided", true, true), - (k, dtype) => crate::bail!("Metal reduce op {k:?} {dtype:?} not implemented"), + (k, dtype) => crate::bail!("Metal strided reduce op {k:?} {dtype:?} not implemented"), }; if check_empty && layout.shape().elem_count() == 0 { Err(crate::Error::EmptyTensor { op: "reduce" }.bt())? diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs index edc5209bcc..6de44f9c6f 100644 --- a/candle-metal-kernels/src/lib.rs +++ b/candle-metal-kernels/src/lib.rs @@ -5,14 +5,12 @@ use metal::{ use std::collections::HashMap; use std::ffi::c_void; use std::sync::RwLock; - pub mod mlx_gemm; pub mod sort; pub mod utils; -pub use utils::BufferOffset; - pub use mlx_gemm::{call_mlx_gemm, GemmDType}; pub use sort::{call_arg_sort, call_mlx_arg_sort}; +pub use utils::BufferOffset; use utils::{get_block_dims, linear_split, EncoderParam, EncoderProvider}; const AFFINE: &str = include_str!("affine.metal"); @@ -176,7 +174,7 @@ pub enum MetalKernelError { LockError(String), #[error("Error while loading library: {0}")] LoadLibraryError(String), - #[error("Error while loading function: {0:?}")] + #[error("Error while loading function: {0}")] LoadFunctionError(String), #[error("Failed to create compute function")] FailedToCreateComputeFunction, @@ -635,19 +633,31 @@ pub fn call_reduce_contiguous( ep: impl EncoderProvider, kernels: &Kernels, kernel_name: &'static str, - length: usize, + shape: &[usize], out_length: usize, input: BufferOffset, output: &Buffer, ) -> Result<(), MetalKernelError> { + let length = shape.iter().product::(); + let num_dims = shape.len(); + let work_per_threadgroup = length / out_length; let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?; - let elements_to_sum = length / out_length; let encoder = ep.encoder(); let encoder: &ComputeCommandEncoderRef = encoder.as_ref(); encoder.set_compute_pipeline_state(&pipeline); - set_params!(encoder, (length, elements_to_sum, &input, output)); + set_params!( + encoder, + ( + length, + num_dims, + shape, + work_per_threadgroup, + &input, + output + ) + ); let thread_group_count = MTLSize { width: out_length as u64, @@ -657,9 +667,8 @@ pub fn call_reduce_contiguous( let width = std::cmp::min( pipeline.max_total_threads_per_threadgroup(), - (elements_to_sum as u64).div_ceil(2), - ) - .next_power_of_two(); + (work_per_threadgroup / 2).next_power_of_two() as NSUInteger, + ); let thread_group_size = MTLSize { width, @@ -686,8 +695,9 @@ pub fn call_reduce_strided( output: &Buffer, ) -> Result<(), MetalKernelError> { let length: usize = shape.iter().product(); + let num_dims = 
shape.len(); + let work_per_threadgroup = length / out_length; let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?; - let elements_to_sum = length / out_length; let encoder = ep.encoder(); let encoder: &ComputeCommandEncoderRef = encoder.as_ref(); @@ -695,7 +705,15 @@ pub fn call_reduce_strided( set_params!( encoder, - (shape.len(), shape, strides, elements_to_sum, &input, output) + ( + length, + num_dims, + shape, + strides, + work_per_threadgroup, + &input, + output + ) ); let thread_group_count = MTLSize { @@ -706,16 +724,14 @@ pub fn call_reduce_strided( let width = std::cmp::min( pipeline.max_total_threads_per_threadgroup(), - elements_to_sum as u64, - ) - .next_power_of_two(); + (work_per_threadgroup / 2).next_power_of_two() as NSUInteger, + ); let thread_group_size = MTLSize { width, height: 1, depth: 1, }; - encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read); encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); @@ -729,11 +745,13 @@ pub fn call_last_softmax( kernels: &Kernels, kernel_name: &'static str, length: usize, - elements_to_sum: usize, + elements: usize, input: &Buffer, input_offset: usize, output: &Buffer, ) -> Result<(), MetalKernelError> { + let work_per_threadgroup = elements; + let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?; let encoder = ep.encoder(); let encoder: &ComputeCommandEncoderRef = encoder.as_ref(); @@ -741,29 +759,27 @@ pub fn call_last_softmax( set_params!( encoder, - (length, elements_to_sum, (input, input_offset), output) + (length, work_per_threadgroup, (input, input_offset), output) ); - let out_length = length / elements_to_sum; + let out_length = length / work_per_threadgroup; let thread_group_count = MTLSize { - width: out_length as u64, + width: out_length as NSUInteger, height: 1, depth: 1, }; let width = std::cmp::min( pipeline.max_total_threads_per_threadgroup(), - elements_to_sum as u64, - ) - .next_power_of_two(); + (work_per_threadgroup / 2).next_power_of_two() as NSUInteger, + ); let thread_group_size = MTLSize { width, height: 1, depth: 1, }; - encoder.use_resource(input, metal::MTLResourceUsage::Read); encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); diff --git a/candle-metal-kernels/src/reduce.metal b/candle-metal-kernels/src/reduce.metal index e009ca1d6a..291c81e631 100644 --- a/candle-metal-kernels/src/reduce.metal +++ b/candle-metal-kernels/src/reduce.metal @@ -1,14 +1,41 @@ #include +#include using namespace metal; -#define MAX(x, y) ((x) > (y) ? (x) : (y)) -#define MIN(x, y) ((x) < (y) ? (x) : (y)) +METAL_FUNC uint nonzero(uint n) { + return n == 0 ? 1 : n; +} + +template +constexpr uint nonzero() { + return N == 0 ? 
1 : N;
+}
+
+template <typename T>
+constexpr ushort granularity() {
+    return nonzero<vec_elements<T>::value>();
+}
+
+METAL_FUNC uint next_p2(uint x) {
+    return 1 << (32 - clz(x - 1));
+}
+
+METAL_FUNC uint prev_p2(uint x) {
+    return 1 << (31 - clz(x));
+}
+
+constant uint MAX_SHARED_MEM = 32767;
+
+template <typename T>
+METAL_FUNC uint max_shared_mem(uint n) {
+    return min(n, prev_p2(MAX_SHARED_MEM / sizeof(T)));
+}

 METAL_FUNC uint get_strided_index(
     uint idx,
-    constant size_t &num_dims,
-    constant size_t *dims,
-    constant size_t *strides
+    constant const uint &num_dims,
+    constant const size_t *dims,
+    constant const size_t *strides
 ) {
     uint strided_i = 0;
     for (uint d = 0; d < num_dims; d++) {
@@ -19,289 +46,904 @@ METAL_FUNC uint get_strided_index(
     return strided_i;
 }

-constant int THREADGROUP_SIZE = 2048;
+struct Divide {
+    template <typename T>
+    METAL_FUNC T operator()(T a, T b) { return a / b; }
+    METAL_FUNC float operator()(float a, float b) { return fast::divide(a, b); }
+    METAL_FUNC half operator()(half a, half b) { return divide(a, b); }
+    #if defined(__HAVE_BFLOAT__)
+    METAL_FUNC bfloat operator()(bfloat a, bfloat b) { return static_cast<bfloat>(fast::divide(a, b)); }
+    #endif
+};
+
+struct Exp {
+    template <typename T>
+    METAL_FUNC T operator()(T a) { return fast::exp(a); }
+    METAL_FUNC float operator()(float a) { return fast::exp(a); }
+    METAL_FUNC half operator()(half a) { return exp(a); }
+    #if defined(__HAVE_BFLOAT__)
+    METAL_FUNC bfloat operator()(bfloat a) { return static_cast<bfloat>(fast::exp(a)); }
+    #endif
+};
+
+
+// Keeps track of the index of the value in the reduction operation (argmin, argmax, etc.)
+// and the value itself. The index is also used to break ties in the reduction operation.
+template <typename T>
+struct indexed {
+    uint i;
+    T val;
+
+    constexpr indexed() threadgroup = default;
+};
+
+template <typename T>
+struct is_indexed_type {
+    static constant constexpr bool value = false;
+};
+
+template <typename T>
+constexpr constant bool is_indexed_t = is_indexed_type<T>::value;
+
+template <typename T>
+struct is_indexed_type<indexed<T>> {
+    static constant constexpr bool value = true;
+};
+
+template <typename T>
+constexpr constant bool not_indexed_t = !is_indexed_t<T>;

 template <typename T>
-METAL_FUNC void argmin(
-    constant size_t &num_dims,
-    constant size_t *dims,
-    constant size_t *strides,
-    constant size_t &el_to_sum_per_block,
-    device const T *src,
-    device uint *dst,
-    uint id,
-    uint tid,
-    uint dst_id,
-    uint block_dim,
-    threadgroup T *shared_memory,
-    threadgroup uint *shared_indices
-) {
-    bool notset = true;
-    // Elements summed in this block range from dst_id * el_to_sum_per_block
-    // to (dst_id + 1) * el_to_sum_per_block.
-    size_t start_idx = dst_id * el_to_sum_per_block;
-    size_t stop_idx = start_idx + el_to_sum_per_block;
-    size_t idx = start_idx + tid;
-    while (idx < stop_idx) {
-        // TODO: Fast version for the contiguous case.
-        size_t strided_i = get_strided_index(idx, num_dims, dims, strides);
-        if (notset || src[strided_i] < shared_memory[tid]) {
-            shared_memory[tid] = src[strided_i];
-            /* Assume that the reduction takes place over the last dimension which is contiguous. */
-            shared_indices[tid] = idx % dims[num_dims - 1];
-            notset = false;
-        }
-        idx += block_dim;
-    }
+constexpr METAL_FUNC bool operator<(indexed<T> lhs, indexed<T> rhs) {
+    return lhs.val < rhs.val || (lhs.val == rhs.val && lhs.i < rhs.i);
+}

-    threadgroup_barrier(mem_flags::mem_none);
-    // reduction in shared memory
-    for (uint s = block_dim / 2; s > 0; s >>= 1) {
-        if (tid < s && shared_memory[tid + s] < shared_memory[tid]) {
-            shared_indices[tid] = shared_indices[tid + s];
-            shared_memory[tid] = shared_memory[tid + s];
-        } \
-        threadgroup_barrier(mem_flags::mem_none);
+template <typename T>
+constexpr METAL_FUNC bool operator>(indexed<T> lhs, indexed<T> rhs) {
+    return lhs.val > rhs.val || (lhs.val == rhs.val && lhs.i < rhs.i);
+}
+
+template <typename T>
+struct _numeric_limits_impl<indexed<T>> {
+    static constexpr METAL_FUNC indexed<T> lowest() {
+        return indexed<T>{ 0, numeric_limits<T>::lowest() };
     }
-    if (tid == 0) {
-        dst[dst_id] = shared_indices[0];
+
+    static constexpr METAL_FUNC indexed<T> max() {
+        return indexed<T>{ 0, numeric_limits<T>::max() };
     }
+};
+
+#if __METAL_VERSION__ >= 220
+METAL_FUNC int64_t simd_shuffle_down(int64_t data, uint16_t delta) {
+    return as_type<int64_t>(simd_shuffle_down(as_type<uint2>(data), delta));
 }
+#endif

-#define ARGMIN(NAME, T, MAXVALUE) \
-kernel void NAME( \
-    constant size_t &num_dims, \
-    constant size_t *dims, \
-    constant size_t *strides, \
-    constant size_t &el_to_sum_per_block, \
-    device const T *src, \
-    device uint *dst, \
-    uint id [[ thread_position_in_grid ]], \
-    uint tid [[ thread_index_in_threadgroup ]], \
-    uint dst_id [[ threadgroup_position_in_grid ]], \
-    uint block_dim [[ threads_per_threadgroup ]] \
-) { \
-    threadgroup T shared_memory[THREADGROUP_SIZE]; \
-    threadgroup uint shared_indices[THREADGROUP_SIZE]; \
-    shared_memory[tid] = MAXVALUE; \
-    shared_indices[tid] = 0xFFFFFFFF; \
-    argmin(num_dims, dims, strides, el_to_sum_per_block, src, dst, id, tid, dst_id, block_dim, shared_memory, shared_indices); \
-} \
+#if defined(__HAVE_BFLOAT__)
+// Metal does not have simd_shuffle_down for bfloat16
+METAL_FUNC bfloat simd_shuffle_down(bfloat value, ushort delta) {
+    return as_type<bfloat>(simd_shuffle_down(as_type<ushort>(value), delta));
+}
+#endif
+
+template <typename T>
+METAL_FUNC indexed<T> simd_shuffle_down(indexed<T> iv, ushort delta) {
+    return indexed<T> {
+        simd_shuffle_down(iv.i, delta),
+        simd_shuffle_down(iv.val, delta)
+    };
+}

 template <typename T>
-METAL_FUNC void argmax(
-    constant size_t & num_dims,
-    constant size_t * dims,
-    constant size_t * strides,
-    constant size_t & el_to_sum_per_block,
-    device const T * src,
-    device uint * dst,
-    uint id,
-    uint tid,
-    uint dst_id,
-    uint block_dim,
-    threadgroup T * shared_memory,
-    threadgroup uint * shared_indices
-  ) {
-    // Elements summed in this block range from dst_id * el_to_sum_per_block
-    // to (dst_id + 1) * el_to_sum_per_block.
-    size_t start_idx = dst_id * el_to_sum_per_block;
-    size_t stop_idx = start_idx + el_to_sum_per_block;
-    size_t idx = start_idx + tid;
-    bool notset = true;
-    while (idx < stop_idx) {
-        // TODO: Fast version for the contiguous case.
- size_t strided_i = get_strided_index(idx, num_dims, dims, strides); - if (notset || shared_memory[tid] < src[strided_i]) { - shared_memory[tid] = src[strided_i]; - shared_indices[tid] = idx % dims[num_dims - 1]; - notset = false; - } - idx += block_dim; +struct Sum { + static constexpr METAL_FUNC T init() { + return 0; + } + static METAL_FUNC T simd_op(T a) { + return simd_sum(a); } - threadgroup_barrier(mem_flags::mem_none); + template + METAL_FUNC V operator()(V a, V b) { + return a + b; + } +}; - // reduction in shared memory - for (uint s = block_dim / 2; s > 0; s >>= 1) { - if (tid < s && shared_memory[tid + s] > shared_memory[tid]) { - shared_indices[tid] = shared_indices[tid + s]; - shared_memory[tid] = shared_memory[tid + s]; - } - threadgroup_barrier(mem_flags::mem_none); +template +struct Mul { + static constexpr METAL_FUNC T init() { + return 1; + } + static METAL_FUNC T simd_op(T a) { + return simd_product(a); } - // Thread 0 writes the result of the reduction - if (tid == 0) { - dst[dst_id] = shared_indices[0]; + template + METAL_FUNC V operator()(V a, V b) { + return a * b; } - } +}; -#define ARGMAX(NAME, T, MINVALUE) \ -kernel void NAME( \ - constant size_t &num_dims, \ - constant size_t *dims, \ - constant size_t *strides, \ - constant size_t &el_to_sum_per_block, \ - device const T *src, \ - device uint *dst, \ - uint id [[ thread_position_in_grid ]], \ - uint tid [[ thread_index_in_threadgroup ]], \ - uint dst_id [[ threadgroup_position_in_grid ]], \ - uint block_dim [[ threads_per_threadgroup ]] \ -) { \ - threadgroup T shared_memory[THREADGROUP_SIZE]; \ - threadgroup uint shared_indices[THREADGROUP_SIZE]; \ - shared_memory[tid] = MINVALUE; \ - shared_indices[tid] = 0xFFFFFFFF; \ - argmax(num_dims, dims, strides, el_to_sum_per_block, src, dst, id, tid, dst_id, block_dim, shared_memory, shared_indices); \ -} \ +template +struct Min { + static constexpr METAL_FUNC T init() { + return numeric_limits::max(); + } + static METAL_FUNC T simd_op(T a) { + return simd_min(a); + } + + template + METAL_FUNC V operator()(V a, V b) { return a < b ? a : b; } + + METAL_FUNC float operator()(float a, float b) { return fast::min(a, b); } + METAL_FUNC half operator()(half a, half b) { return min(a, b); } + METAL_FUNC uint operator()(uint a, uint b) { return min(a, b); } + METAL_FUNC uchar operator()(uchar a, uchar b) { return min(a, b); } + + #if __METAL_VERSION__ >= 220 + METAL_FUNC long operator()(long a, long b) { return min(a, b); } + #endif + + #if defined(__HAVE_BFLOAT__) + METAL_FUNC bfloat operator()(bfloat a, bfloat b) { return static_cast(fast::min(static_cast(a), static_cast(b))); } + #endif +}; template -METAL_FUNC void reduce( - constant size_t & num_dims, - constant size_t * dims, - constant size_t * strides, - constant size_t & el_to_sum_per_block, - device const T * src, - device T * dst, - uint id, - uint tid, - uint dst_id, - uint block_dim, - threadgroup T * shared_memory, - T (*fn)(T, T) -) { - // Elements summed in this block range from dst_id * el_to_sum_per_block - // to (dst_id + 1) * el_to_sum_per_block. - size_t start_idx = dst_id * el_to_sum_per_block; - size_t stop_idx = start_idx + el_to_sum_per_block; - size_t idx = start_idx + tid; - while (idx < stop_idx) { - // TODO: Fast version for the contiguous case. 
- size_t strided_i = get_strided_index(idx, num_dims, dims, strides); - T x = shared_memory[tid]; - T y = src[strided_i]; - shared_memory[tid] = fn(x, y); - idx += block_dim; +struct Max { + static constexpr METAL_FUNC T init() { + return numeric_limits::lowest(); + } + static METAL_FUNC T simd_op(T a) { + return simd_max(a); } - threadgroup_barrier(mem_flags::mem_none); + template + METAL_FUNC V operator()(V a, V b) { return a > b ? a : b; } - // reduction in shared memory - for (uint s = block_dim / 2; s > 0; s >>= 1) { - if (tid < s) { - T x = shared_memory[tid]; - T y = shared_memory[tid + s]; - shared_memory[tid] = fn(x, y); + METAL_FUNC float operator()(float a, float b) { return fast::max(a, b); } + METAL_FUNC half operator()(half a, half b) { return max(a, b); } + METAL_FUNC uint operator()(uint a, uint b) { return max(a, b); } + METAL_FUNC uchar operator()(uchar a, uchar b) { return max(a, b); } + + #if __METAL_VERSION__ >= 220 + METAL_FUNC long operator()(long a, long b) { return max(a, b); } + #endif + + #if defined(__HAVE_BFLOAT__) + METAL_FUNC bfloat operator()(bfloat a, bfloat b) { return static_cast(fast::max(static_cast(a), static_cast(b))); } + #endif +}; + +template +constexpr constant bool is_simd_t = __is_valid_simdgroup_type::value; + +template +struct is_valid_simd_type { + static constant constexpr bool value = false; +}; + +template +constexpr constant bool is_valid_simd_t = is_valid_simd_type::value; + +template +struct is_valid_simd_type>> { + static constant constexpr bool value = true; +}; + +template +struct is_valid_simd_type, typename metal::enable_if_t>> { + static constant constexpr bool value = true; +}; + +#if __METAL_VERSION__ >= 220 +template <> +struct is_valid_simd_type { + static constant constexpr bool value = true; +}; +#endif + +#if defined(__HAVE_BFLOAT__) +template <> +struct is_valid_simd_type { + static constant constexpr bool value = true; +}; +#endif + +template +struct is_simd_op { + static constant constexpr bool value = false; +}; +template +struct is_simd_op, typename metal::enable_if_t>> { + static constant constexpr bool value = true; +}; +template +struct is_simd_op, typename metal::enable_if_t>> { + static constant constexpr bool value = true; +}; +template +struct is_simd_op, typename metal::enable_if_t>> { + static constant constexpr bool value = true; +}; +template +struct is_simd_op, typename metal::enable_if_t>> { + static constant constexpr bool value = true; +}; + +// Helper struct for applying operators. +// The overloaded operator() function is used to apply an operator to two values. +template +struct operation; + +// Specialization for scalar values. +template +struct operation { + OP op; + + METAL_FUNC T operator()(T a, T b) { + return op(a, b); + } +}; + +// Specialization for indexed values. +template +struct operation> { + OP op; + + METAL_FUNC indexed operator()(indexed a, indexed b) { + return op(a, b); + } + METAL_FUNC indexed operator()(indexed a, T b, uint idx) { + return this->operator()(a, indexed{ idx, b }); + } +}; + +// Load elements from global memory into shared memory. +// Handles both indexed and non-indexed types by using operate. 
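+// A loader is specialized on STRIDED (linear vs. strided addressing) and, via the trailing
+// template parameter, on whether R is an indexed<T> result (argmin/argmax) or a plain value.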
+template< + typename T, + typename R, + typename OP, + ushort BLOCKSIZE, + bool STRIDED = false, + typename _E = void +> +struct loader; + + +// Contiguous +template< + typename T, + typename R, + typename OP, + ushort BLOCKSIZE +> +struct loader>> { + operation operate; + + METAL_FUNC R operator()( + R value, + constant uint &src_numel, + constant uint &el_per_block, + device const T *src, + const uint offset, + const uint tid + ) { + uint idx = tid + offset; + const uint stop_idx = min(el_per_block + offset, src_numel); + + #pragma clang loop unroll(full) + for (uint i = idx; i < stop_idx; i += BLOCKSIZE) { + value = operate(value, src[i]); } - threadgroup_barrier(mem_flags::mem_none); + return value; } - if (tid == 0) { - dst[dst_id] = shared_memory[0]; + METAL_FUNC R operator()( + R value, + constant uint &src_numel, + constant uint &num_dims, + constant size_t *dims, + constant size_t *strides, + constant uint &el_per_block, + device const T *src, + const uint offset, + const uint tid + ) { + return this->operator()(value, src_numel, el_per_block, src, offset, tid); } -} +}; -#define REDUCE(FN, NAME, T, START) \ -METAL_FUNC T NAME##_##op(T x, T y) { return FN; } \ -kernel void NAME( \ - constant size_t &num_dims, \ - constant size_t *dims, \ - constant size_t *strides, \ - constant size_t &el_to_sum_per_block, \ - device const T *src, \ - device T *dst, \ - uint id [[ thread_position_in_grid ]], \ - uint tid [[ thread_index_in_threadgroup ]], \ - uint dst_id [[ threadgroup_position_in_grid ]], \ - uint block_dim [[ threads_per_threadgroup ]] \ -) { \ - threadgroup T shared_memory[THREADGROUP_SIZE]; \ - shared_memory[tid] = START; \ - reduce(num_dims, dims, strides, el_to_sum_per_block, src, dst, id, tid, dst_id, block_dim, shared_memory, NAME##_##op); \ -} \ +// Strided +template< + typename T, + typename R, + typename OP, + ushort BLOCKSIZE +> +struct loader>> { + operation operate; -template -METAL_FUNC void softmax( - constant size_t & src_numel, - constant size_t & el_to_sum_per_block, - device const T * src, - device T * dst, - uint id, - uint tid, - uint dst_id, - uint block_dim, - threadgroup float * shared_memory -) { - size_t start_idx = dst_id * el_to_sum_per_block; - size_t stop_idx = min(start_idx + el_to_sum_per_block, src_numel); - size_t idx = start_idx + tid; - float tmp = -INFINITY; - while (idx < stop_idx) { - tmp = MAX(tmp, float(src[idx])); - idx += block_dim; + METAL_FUNC R operator()( + R value, + constant uint &src_numel, + constant uint &num_dims, + constant size_t *dims, + constant size_t *strides, + constant uint &el_per_block, + device const T *src, + const uint offset, + const uint tid + ) { + const uint idx = tid + offset; + const uint stop_idx = min(el_per_block + offset, src_numel); + + #pragma clang loop unroll(full) + for (uint i = idx; i < stop_idx; i += BLOCKSIZE) { + value = operate(value, src[get_strided_index(i, num_dims, dims, strides)]); + } + return value; } - shared_memory[tid] = tmp; +}; - threadgroup_barrier(mem_flags::mem_threadgroup); +// Indexed contiguous +template< + typename T, + typename R, + typename OP, + ushort BLOCKSIZE +> +struct loader>> { + operation operate; - for (uint s = block_dim / 2; s > 0; s >>= 1) { - if (tid < s) { - shared_memory[tid] = MAX(shared_memory[tid], shared_memory[tid + s]);\ + METAL_FUNC R operator()( + R value, + constant uint &src_numel, + constant uint &num_dims, + constant size_t *dims, + constant size_t *strides, + constant uint &el_per_block, + device const T *src, + const uint offset, + const uint 
tid + ) { + const uint thread_id = tid + offset; + const uint stop_idx = min(el_per_block + offset, src_numel); + + #pragma clang loop unroll(full) + for (uint i = thread_id; i < stop_idx; i += BLOCKSIZE) { + value = operate(value, src[i], i % dims[num_dims - 1]); } - threadgroup_barrier(mem_flags::mem_threadgroup); + return value; } +}; - /* wait for shared_memory[0] to be filled */ - threadgroup_barrier(mem_flags::mem_threadgroup); +// Indexed strided +template< + typename T, + typename R, + typename OP, + ushort BLOCKSIZE +> +struct loader>> { + operation operate; - float _max = shared_memory[0]; + METAL_FUNC R operator()( + R value, + constant uint &src_numel, + constant uint &num_dims, + constant size_t *dims, + constant size_t *strides, + constant uint &el_per_block, + device const T *src, + const uint offset, + const uint tid + ) { + const uint thread_id = tid + offset; + const uint stop_idx = min(el_per_block + offset, src_numel); - /* prevent tid=0 from overwriting _max before other threads have written it */ - threadgroup_barrier(mem_flags::mem_threadgroup); - shared_memory[tid] = 0; + #pragma clang loop unroll(full) + for (uint i = thread_id; i < stop_idx; i += BLOCKSIZE) { + value = operate(value, src[get_strided_index(i, num_dims, dims, strides)], i % dims[num_dims - 1]); + } + return value; + } +}; - idx = start_idx + tid; - while (idx < stop_idx) { - const float val = exp(float(src[idx]) - _max); - dst[idx] = T(val); - shared_memory[tid] += val; - idx += block_dim; +template< + typename OP, + ushort BLOCKSIZE, + typename T, + typename _E = void +> +struct simdgroup_reducer; + +// Specialization for built-in simd operations. +template +struct simdgroup_reducer::value && is_valid_simd_t>> { + METAL_FUNC T operator()(T value) { + return OP::simd_op(value); } - threadgroup_barrier(mem_flags::mem_threadgroup); - for (uint s = block_dim / 2; s > 0; s >>= 1) { - if (tid < s) { - shared_memory[tid] += shared_memory[tid + s]; +}; + +// Specialization for custom (non-built-in) simd operations. +template +struct simdgroup_reducer::value && is_valid_simd_t>> { + operation op; + + METAL_FUNC T operator()(T value) { + if (BLOCKSIZE >= 32) value = op(value, simd_shuffle_down(value, 16)); + if (BLOCKSIZE >= 16) value = op(value, simd_shuffle_down(value, 8)); + if (BLOCKSIZE >= 8) value = op(value, simd_shuffle_down(value, 4)); + if (BLOCKSIZE >= 4) value = op(value, simd_shuffle_down(value, 2)); + if (BLOCKSIZE >= 2) value = op(value, simd_shuffle_down(value, 1)); + return value; + } +}; + +template +struct block_reducer { + simdgroup_reducer simd_reduce; + operation operate; + threadgroup T *shared; + + block_reducer(threadgroup T shared[BLOCKSIZE]) { + this->shared = shared; + } + + METAL_FUNC T operator()(T value, const uint tid) { + if (BLOCKSIZE >= 64) { + // Only store in threadgroup shared memory if needed. + shared[tid] = value; + // Threadgroup barrier is needed to ensure that all threads have written to shared memory + threadgroup_barrier(mem_flags::mem_none); } - threadgroup_barrier(mem_flags::mem_threadgroup); + + #pragma clang loop unroll(full) + for (ushort s = BLOCKSIZE / 2; s >= 64; s >>= 1) { + if (tid < s) shared[tid] = operate(shared[tid], shared[tid + s]); + threadgroup_barrier(mem_flags::mem_none); + } + if (tid < 32) { + // Last shared memory reduce can be done without tid < s check. + if (BLOCKSIZE >= 64) { + value = operate(shared[tid], shared[tid + 32]); + simdgroup_barrier(mem_flags::mem_none); + } + // Remaining 32 threads can be reduced with simdgroup_reduce. 
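+            // (simdgroup_reducer uses either the built-in simd_* reduction or a shuffle-down loop.)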
+ value = simd_reduce(value); + } + return value; } +}; - const T inv_acc = T(1.0 / shared_memory[0]); - idx = start_idx + tid; - while (idx < stop_idx) { - dst[idx] *= inv_acc; - idx += block_dim; +// Inspired by "Optimizing Parallel Reduction in CUDA" by Mark Harris +template< + typename T, + typename R, + typename OP, + ushort BLOCKSIZE, + bool STRIDED = false +> +METAL_FUNC void reduce( + constant uint &src_numel, + constant uint &num_dims, + constant size_t *dims, + constant size_t *strides, + constant uint &el_per_block, + device const T *src, + device R *dst, + threadgroup R shared[BLOCKSIZE], + uint tid [[ thread_index_in_threadgroup ]], + uint dst_id [[ threadgroup_position_in_grid ]] +) { + loader load; + block_reducer reduce(shared); + + // Calcluate offset for the threadgroup of current thread + const uint offset = dst_id * el_per_block; + + // Load with reduction from global memory into shared memory + auto value = load( + OP::init(), + src_numel, + num_dims, + dims, + strides, + el_per_block, + src, + offset, + tid + ); + // Complete reduction + R result = reduce(value, tid); + + if (tid == 0) dst[dst_id] = result; +} + +#define reduce_case(OP, T, R, N) \ +case N: { \ + threadgroup R shared[N]; \ + reduce, N, STRIDED>( \ + src_numel, \ + num_dims, \ + dims, \ + strides, \ + el_per_block, \ + src, \ + dst, \ + shared, \ + tid, \ + dst_id); \ + break; \ +} + +#define ARG(...) __VA_ARGS__ + +#define impl_reduce_inner(OP, NAME, T) \ +kernel void NAME( \ + constant uint &src_numel, \ + constant uint &num_dims, \ + constant size_t *dims, \ + constant uint &el_per_block, \ + device const T *src, \ + device T *dst, \ + uint tid [[ thread_index_in_threadgroup ]], \ + uint dst_id [[ threadgroup_position_in_grid ]], \ + uint block_dim [[ threads_per_threadgroup ]] \ +) { \ + constant size_t *strides = {}; \ + const bool STRIDED = false; \ + switch (max_shared_mem(block_dim)) { \ + reduce_case(OP, ARG(T), ARG(T), 2048); \ + reduce_case(OP, ARG(T), ARG(T), 1024); \ + reduce_case(OP, ARG(T), ARG(T), 512); \ + reduce_case(OP, ARG(T), ARG(T), 256); \ + reduce_case(OP, ARG(T), ARG(T), 128); \ + reduce_case(OP, ARG(T), ARG(T), 64); \ + reduce_case(OP, ARG(T), ARG(T), 32); \ + reduce_case(OP, ARG(T), ARG(T), 16); \ + reduce_case(OP, ARG(T), ARG(T), 8); \ + reduce_case(OP, ARG(T), ARG(T), 4); \ + reduce_case(OP, ARG(T), ARG(T), 2); \ + reduce_case(OP, ARG(T), ARG(T), 1); \ + } \ +} + + +#define impl_reduce_strided(OP, NAME, T) \ +kernel void NAME##_strided( \ + constant uint &src_numel, \ + constant uint &num_dims, \ + constant size_t *dims, \ + constant size_t *strides, \ + constant uint &el_per_block, \ + device const T *src, \ + device T *dst, \ + uint tid [[ thread_index_in_threadgroup ]], \ + uint dst_id [[ threadgroup_position_in_grid ]], \ + uint block_dim [[ threads_per_threadgroup ]] \ +) { \ + const bool STRIDED = true; \ + switch (max_shared_mem(block_dim)) { \ + reduce_case(OP, ARG(T), ARG(T), 2048); \ + reduce_case(OP, ARG(T), ARG(T), 1024); \ + reduce_case(OP, ARG(T), ARG(T), 512); \ + reduce_case(OP, ARG(T), ARG(T), 256); \ + reduce_case(OP, ARG(T), ARG(T), 128); \ + reduce_case(OP, ARG(T), ARG(T), 64); \ + reduce_case(OP, ARG(T), ARG(T), 32); \ + reduce_case(OP, ARG(T), ARG(T), 16); \ + reduce_case(OP, ARG(T), ARG(T), 8); \ + reduce_case(OP, ARG(T), ARG(T), 4); \ + reduce_case(OP, ARG(T), ARG(T), 2); \ + reduce_case(OP, ARG(T), ARG(T), 1); \ + } \ +} + +#define impl_reduce(OP, NAME, T) \ +impl_reduce_inner(OP, NAME, T) \ +impl_reduce_strided(OP, NAME, T) \ + +template< + typename 
T, + typename ReductionOp, + ushort BLOCKSIZE, + bool STRIDED = false +> +METAL_FUNC void reduce( + constant uint &src_numel, + constant uint &num_dims, + constant size_t *dims, + constant size_t *strides, + constant uint &el_per_block, + device const T *src, + device uint *dst, + threadgroup indexed shared[BLOCKSIZE], + uint tid [[ thread_index_in_threadgroup ]], + uint dst_id [[ threadgroup_position_in_grid ]] +) { + using I = indexed; + loader, ReductionOp, BLOCKSIZE, STRIDED> load; + block_reducer reduce(shared); + + // Calcluate offset for the threadgroup of current thread + const uint offset = dst_id * el_per_block; + + // Load with reduction from global memory into shared memory + indexed value = load( + ReductionOp::init(), + src_numel, + num_dims, + dims, + strides, + el_per_block, + src, + offset, + tid + ); + + // Complete reduction + I result = reduce(value, tid); + + // Return index of reduce result + if (tid == 0) dst[dst_id] = result.i; +} + +#define arg_reduce_case(OP, T, N) \ +case N: { \ + using I = indexed; \ + threadgroup I shared[N]; \ + reduce, N, STRIDED>( \ + src_numel, \ + num_dims, \ + dims, \ + strides, \ + el_per_block, \ + src, \ + dst, \ + shared, \ + tid, \ + dst_id); \ + break; \ +} + +#define impl_arg_reduce_inner(OP, NAME, T) \ +kernel void NAME( \ + constant uint &src_numel, \ + constant uint &num_dims, \ + constant size_t *dims, \ + constant uint &el_per_block, \ + device const T *src, \ + device uint *dst, \ + uint tid [[ thread_index_in_threadgroup ]], \ + uint dst_id [[ threadgroup_position_in_grid ]], \ + uint block_dim [[ threads_per_threadgroup ]] \ +) { \ + constant size_t *strides = {}; \ + const bool STRIDED = false; \ + switch (max_shared_mem>(block_dim)) { \ + arg_reduce_case(OP, ARG(T), 1024); \ + arg_reduce_case(OP, ARG(T), 512); \ + arg_reduce_case(OP, ARG(T), 256); \ + arg_reduce_case(OP, ARG(T), 128); \ + arg_reduce_case(OP, ARG(T), 64); \ + arg_reduce_case(OP, ARG(T), 32); \ + arg_reduce_case(OP, ARG(T), 16); \ + arg_reduce_case(OP, ARG(T), 8); \ + arg_reduce_case(OP, ARG(T), 4); \ + arg_reduce_case(OP, ARG(T), 2); \ + arg_reduce_case(OP, ARG(T), 1); \ + } \ +} \ + + +#define impl_arg_reduce_strided(OP, NAME, T) \ +kernel void NAME##_strided( \ + constant uint &src_numel, \ + constant uint &num_dims, \ + constant size_t *dims, \ + constant size_t *strides, \ + constant uint &el_per_block, \ + device const T *src, \ + device uint *dst, \ + uint tid [[ thread_index_in_threadgroup ]], \ + uint dst_id [[ threadgroup_position_in_grid ]], \ + uint block_dim [[ threads_per_threadgroup ]] \ +) { \ + const bool STRIDED = true; \ + const bool INDEXED = true; \ + switch (max_shared_mem>(block_dim)) { \ + arg_reduce_case(OP, ARG(T), 1024); \ + arg_reduce_case(OP, ARG(T), 512); \ + arg_reduce_case(OP, ARG(T), 256); \ + arg_reduce_case(OP, ARG(T), 128); \ + arg_reduce_case(OP, ARG(T), 64); \ + arg_reduce_case(OP, ARG(T), 32); \ + arg_reduce_case(OP, ARG(T), 16); \ + arg_reduce_case(OP, ARG(T), 8); \ + arg_reduce_case(OP, ARG(T), 4); \ + arg_reduce_case(OP, ARG(T), 2); \ + arg_reduce_case(OP, ARG(T), 1); \ + } \ +} + + +#define impl_arg_reduce(OP, NAME, T) \ +impl_arg_reduce_inner(OP, NAME, T) \ +impl_arg_reduce_strided(OP, NAME, T) \ + +// Contains the intermediate results for the online softmax calculation. 
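+// Two MD values can be merged associatively (see MDReduceOp below), so partial results from
+// different threads and simd groups can be combined during the reduction.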
+// m: max +// d: sum of the exponentials +template +struct MD { + T m; + float d; + + constexpr MD() = default; + constexpr MD() threadgroup = default; +}; + +// Enable operations for softmax MD +template +struct operation> { + OP op; + + METAL_FUNC MD operator()(MD a, MD b) { + return op(a, b); } + + METAL_FUNC MD operator()(MD a, T b) { + return this->operator()(a, MD{ b, static_cast(1.0) }); + } +}; + +template +METAL_FUNC MD simd_shuffle_down(MD md, ushort delta) { + return MD { + simd_shuffle_down(md.m, delta), + simd_shuffle_down(md.d, delta) + }; +} + +// Enable simd_shuffle_down for softmax MD +template +struct is_valid_simd_type, typename metal::enable_if_t>> { + static constant constexpr bool value = true; +}; + +template +struct MDReduceOp { + Exp fast_exp; + + static constexpr METAL_FUNC MD init() { + return MD{ numeric_limits::lowest(), 0 }; + } + + METAL_FUNC MD operator()(MD a, MD b) { + bool a_bigger = a.m > b.m; + MD bigger_m = a_bigger ? a : b; + MD smaller_m = a_bigger ? b : a; + MD res; + res.d = bigger_m.d + smaller_m.d * fast_exp(smaller_m.m - bigger_m.m); + res.m = bigger_m.m; + return res; + } +}; + + +template +struct finalize_softmax { + Divide fast_divide; + Exp fast_exp; + + METAL_FUNC void operator()( + device const T *src, + device T *dst, + threadgroup MD &md_total, + const uint thread_id, + const uint stop_idx + ) { + const float d_total_inverse = fast_divide(1.0, md_total.d); + for (uint idx = thread_id; idx < stop_idx; idx += BLOCKSIZE) { + dst[idx] = static_cast(fast_exp(src[idx] - md_total.m) * d_total_inverse); + } + } +}; + +// Welford's algorithm approach for an online softmax implementation. +// Same as the Online normalizer calculation for softmax: https://arxiv.org/pdf/1805.02867.pdf +template +METAL_FUNC void softmax( + constant uint &src_numel, + constant uint &el_per_block, + device const T *src, + device T *dst, + threadgroup MD shared[BLOCKSIZE], + threadgroup MD &md_total, + + uint tid [[ thread_index_in_threadgroup ]], + uint dst_id [[ threadgroup_position_in_grid ]] +) { + using MDReduceOp = MDReduceOp; + + loader, MDReduceOp, BLOCKSIZE> load; + block_reducer, MDReduceOp, BLOCKSIZE> reduce(shared); + finalize_softmax softmax_finalize; + + // Calcluate offset for the threadgroup of current thread; + const uint offset = dst_id * el_per_block; + + // Calculate partial result for current thread + MD md_partial = MD { numeric_limits::lowest(), 0 }; + md_partial = load( + md_partial, + src_numel, + el_per_block, + src, + offset, + tid + ); + + // Reduce in shared memory + MD md = reduce(md_partial, tid); + + if (tid == 0) md_total = md; + threadgroup_barrier(mem_flags::mem_none); + + // Finalize softmax + const uint thread_id = tid + offset; + const uint stop_idx = min(el_per_block + offset, src_numel); + softmax_finalize(src, dst, md_total, thread_id, stop_idx); +} + +#define softmax_case(T, N) \ +case N: { \ + threadgroup MD shared[N]; \ + threadgroup MD md_total; \ + softmax( \ + src_numel, \ + el_per_block, \ + src, \ + dst, \ + shared, \ + md_total, \ + tid, \ + dst_id); \ + break; \ +} + +#define impl_softmax(NAME, T) \ +kernel void NAME( \ + constant uint &src_numel, \ + constant uint &el_per_block, \ + device const T *src, \ + device T *dst, \ + uint tid [[ thread_index_in_threadgroup ]], \ + uint dst_id [[ threadgroup_position_in_grid ]], \ + uint block_dim [[ threads_per_threadgroup ]] \ +) { \ + switch (max_shared_mem(block_dim)) { \ + softmax_case(T, 1024); \ + softmax_case(T, 512); \ + softmax_case(T, 256); \ + softmax_case(T, 128); 
\ + softmax_case(T, 64); \ + softmax_case(T, 32); \ + softmax_case(T, 16); \ + softmax_case(T, 8); \ + softmax_case(T, 4); \ + softmax_case(T, 2); \ + softmax_case(T, 1); \ + } \ } -#define SOFTMAX(NAME, T) \ -kernel void NAME( \ - constant size_t &src_numel, \ - constant size_t &el_to_sum_per_block, \ - device const T *src, \ - device T *dst, \ - uint id [[ thread_position_in_grid ]], \ - uint tid [[ thread_index_in_threadgroup ]], \ - uint dst_id [[ threadgroup_position_in_grid ]], \ - uint block_dim [[ threads_per_threadgroup ]] \ -) { \ - threadgroup float shared_memory[THREADGROUP_SIZE]; \ - shared_memory[tid] = -INFINITY; \ - softmax(src_numel, el_to_sum_per_block, src, dst, id, tid, dst_id, block_dim, shared_memory); \ -} \ template METAL_FUNC void rmsnorm( @@ -412,6 +1054,8 @@ METAL_FUNC void layernorm( } } +constant int THREADGROUP_SIZE = 2048; + #define RMSNORM(NAME, T) \ kernel void NAME( \ constant size_t &src_numel, \ @@ -561,32 +1205,6 @@ kernel void FN_NAME_THD( \ rope_thd(b, t, h, d, src, cos, sin, dst, idx); \ }\ -REDUCE(x + y, fast_sum_f32_strided, float, 0) -REDUCE(x + y, fast_sum_u32_strided, uint, 0) -REDUCE(x + y, fast_sum_f16_strided, half, 0) -REDUCE(x + y, fast_sum_u8_strided, uint8_t, 0) -REDUCE(x * y, fast_mul_f32_strided, float, 1) -REDUCE(x * y, fast_mul_u32_strided, uint, 1) -REDUCE(x * y, fast_mul_f16_strided, half, 1) -REDUCE(MAX(x, y), fast_max_f32_strided, float, -HUGE_VALF) -REDUCE(MAX(x, y), fast_max_u32_strided, uint, 0) -REDUCE(MAX(x, y), fast_max_f16_strided, half, -HUGE_VALH) -REDUCE(MAX(x, y), fast_max_u8_strided, uint8_t, 0) -REDUCE(MIN(x, y), fast_min_f32_strided, float, HUGE_VALF) -REDUCE(MIN(x, y), fast_min_u32_strided, uint, 0xFFFFFFFF) -REDUCE(MIN(x, y), fast_min_f16_strided, half, HUGE_VALH) -REDUCE(MIN(x, y), fast_min_u8_strided, uint8_t, 0xFF) -ARGMIN(fast_argmin_f32_strided, float, HUGE_VALF) -ARGMIN(fast_argmin_f16_strided, half, HUGE_VALH) -ARGMIN(fast_argmin_u32_strided, uint, 0xFFFFFFFF) -ARGMIN(fast_argmin_u8_strided, uint8_t, 0xFF) -ARGMAX(fast_argmax_f32_strided, float, -HUGE_VALF) -ARGMAX(fast_argmax_f16_strided, half, -HUGE_VALH) -ARGMAX(fast_argmax_u32_strided, uint, 0) -ARGMAX(fast_argmax_u8_strided, uint8_t, 0) - -SOFTMAX(softmax_f32, float) -SOFTMAX(softmax_f16, half) RMSNORM(rmsnorm_f32, float) RMSNORM(rmsnorm_f16, half) LAYERNORM(layernorm_f32, float) @@ -594,26 +1212,60 @@ LAYERNORM(layernorm_f16, half) ROPE(rope_f32, rope_i_f32, rope_thd_f32, float) ROPE(rope_f16, rope_i_f16, rope_thd_f16, half) +impl_reduce(Sum, fast_sum_f32, float) +impl_reduce(Sum, fast_sum_u32, uint) +impl_reduce(Sum, fast_sum_f16, half) +impl_reduce(Sum, fast_sum_u8, uint8_t) + +impl_reduce(Mul, fast_mul_f32, float) +impl_reduce(Mul, fast_mul_u32, uint) +impl_reduce(Mul, fast_mul_f16, half) +impl_reduce(Mul, fast_mul_u8, uint8_t) + +impl_reduce(Max, fast_max_f32, float) +impl_reduce(Max, fast_max_u32, uint) +impl_reduce(Max, fast_max_f16, half) +impl_reduce(Max, fast_max_u8, uint8_t) + +impl_reduce(Min, fast_min_f32, float) +impl_reduce(Min, fast_min_u32, uint) +impl_reduce(Min, fast_min_f16, half) +impl_reduce(Min, fast_min_u8, uint8_t) + +impl_arg_reduce(Min, fast_argmin_f32, float) +impl_arg_reduce(Min, fast_argmin_f16, half) +impl_arg_reduce(Min, fast_argmin_u32, uint) +impl_arg_reduce(Min, fast_argmin_u8, uint8_t) + +impl_arg_reduce(Max, fast_argmax_f32, float) +impl_arg_reduce(Max, fast_argmax_f16, half) +impl_arg_reduce(Max, fast_argmax_u32, uint) +impl_arg_reduce(Max, fast_argmax_u8, uint8_t) + +impl_softmax(softmax_f32, float) 
+impl_softmax(softmax_f16, half) + #if __METAL_VERSION__ >= 220 -REDUCE(x + y, fast_sum_i64_strided, int64_t, 0) -REDUCE(MIN(x, y), fast_min_i64_strided, int64_t, INT_MAX) -REDUCE(MAX(x, y), fast_max_i64_strided, int64_t, INT_MIN) -ARGMIN(fast_argmin_i64_strided, int64_t, INT_MAX) -ARGMAX(fast_argmax_i64_strided, int64_t, INT_MIN) +impl_reduce(Sum, fast_sum_i64, int64_t) +impl_reduce(Mul, fast_mul_i64, int64_t) +impl_reduce(Min, fast_min_i64, int64_t) +impl_reduce(Max, fast_max_i64, int64_t) + +impl_arg_reduce(Min, fast_argmin_i64, int64_t) +impl_arg_reduce(Max, fast_argmax_i64, int64_t) #endif #if defined(__HAVE_BFLOAT__) -REDUCE(x + y, fast_sum_bf16, bfloat, 0) -REDUCE(x + y, fast_sum_bf16_strided, half, 0) -REDUCE(x * y, fast_mul_bf16, bfloat, 1) -REDUCE(x * y, fast_mul_bf16_strided, bfloat, 1) -REDUCE(MAX(x, y), fast_max_bf16, bfloat, -HUGE_VALBF) -REDUCE(MAX(x, y), fast_max_bf16_strided, bfloat, -HUGE_VALBF) -REDUCE(MIN(x, y), fast_min_bf16, bfloat, HUGE_VALBF) -REDUCE(MIN(x, y), fast_min_bf16_strided, bfloat, HUGE_VALBF) -ARGMIN(fast_argmin_bf16, bfloat, HUGE_VALBF) -ARGMAX(fast_argmax_bf16, bfloat, -HUGE_VALBF) -SOFTMAX(softmax_bf16, bfloat) +impl_reduce(Sum, fast_sum_bf16, bfloat) +impl_reduce(Mul, fast_mul_bf16, bfloat) +impl_reduce(Max, fast_max_bf16, bfloat) +impl_reduce(Min, fast_min_bf16, bfloat) + +impl_arg_reduce(Min, fast_argmin_bf16, bfloat) +impl_arg_reduce(Max, fast_argmax_bf16, bfloat) + +impl_softmax(softmax_bf16, bfloat) + RMSNORM(rmsnorm_bf16, bfloat) LAYERNORM(layernorm_bf16, bfloat) ROPE(rope_bf16, rope_i_bf16, rope_thd_bf16, bfloat) diff --git a/candle-metal-kernels/src/tests.rs b/candle-metal-kernels/src/tests.rs index 546680d4e5..21ade21c4c 100644 --- a/candle-metal-kernels/src/tests.rs +++ b/candle-metal-kernels/src/tests.rs @@ -1,6 +1,8 @@ use super::*; use half::{bf16, f16}; -use metal::MTLResourceOptions; +use metal::{Buffer, Device, MTLResourceOptions}; +use rand::prelude::SliceRandom; +use rand::thread_rng; use rand::Rng; fn read_to_vec(buffer: &Buffer, n: usize) -> Vec { @@ -860,7 +862,12 @@ fn cos_f16() { assert_eq!(approx_f16(expected, 2), vec![0.54, -0.42, -0.99]); } -fn run_reduce(v: &[T], out_length: usize, name: &'static str) -> Vec { +fn run_reduce( + v: &[T], + in_length: usize, + out_length: usize, + name: &'static str, +) -> Vec { let device = device(); let kernels = Kernels::new(); let command_queue = device.new_command_queue(); @@ -868,21 +875,24 @@ fn run_reduce(v: &[T], out_length: usize, name: &'static str) -> Vec()) as u64, options); - let dims = vec![v.len()]; - let strides = vec![1]; - call_reduce_strided( + let output = device.new_buffer((out_length * core::mem::size_of::()) as u64, options); + let shape = vec![in_length]; + match call_reduce_contiguous( &device, command_buffer, &kernels, name, - &dims, - &strides, + &shape, out_length, BufferOffset::zero_offset(&input), &output, - ) - .unwrap(); + ) { + Ok(_) => {} + Err(e) => { + println!("{e}"); + panic!(); + } + } command_buffer.commit(); command_buffer.wait_until_completed(); @@ -914,22 +924,187 @@ fn run_softmax(v: &[T], last_dim: usize, name: &'sta read_to_vec(&output, v.len()) } -#[test] -fn reduce_sum() { - let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0]; - let out_length = 1; +const fn create_array() -> [f32; N] { + let mut array: [f32; N] = [0.0; N]; + let mut i = 1; + while i <= N { + array[i - 1] = i as f32; + i += 1; + } + array +} + +const fn correct_sum() -> [f32; D] { + let mut sum = 0; + let mut results: [f32; D] = [0.0; D]; + let mut i = 1; + let mut j = 1; + while i <= 
N { + sum += i; + i += 1; + if i > j * N / D { + results[j - 1] = sum as f32; + j += 1; + sum = 0; + } + } + results +} + +const fn correct_max() -> [f32; D] { + let mut results: [f32; D] = [0.0; D]; + let mut i = 1; + let mut j = 1; + while i <= N { + i += 1; + if i > j * (N / D) { + results[j - 1] = (i - 1) as f32; + j += 1; + } + } + results +} + +fn correct_argmax(arr: [f32; N]) -> [u32; D] { + let mut max = 0.0; + let mut max_index: u32 = 0; + let mut results: [u32; D] = [0; D]; + let mut i = 0; + let mut j = 1; + while i <= N { + if i >= (j * N / D) { + results[j - 1] = max_index; + max = 0.0; + max_index = 0; + j += 1; + } + if i == N { + break; + } + if arr[i] > max { + max = arr[i]; + max_index = i as u32; + } + i += 1; + } + results +} + +fn reduce_sum_case() { + let mut v = create_array::(); + if D == 1 { + // Hardens 1-dimensional test cases + v.shuffle(&mut thread_rng()); + } + let results = run_reduce(&v, N, D, "fast_sum_f32"); + assert_eq!(approx(results, 4), correct_sum::()); +} + +fn reduce_max_case() { + let mut v = create_array::(); + if D == 1 { + // Hardens 1-dimensional test cases + v.shuffle(&mut thread_rng()); + } + let results = run_reduce(&v, N, D, "fast_max_f32"); + assert_eq!(approx(results, 4), correct_max::()); +} + +fn reduce_argmax_case() { + let mut v = create_array::(); + if D == 1 { + // Hardens 1-dimensional test cases + v.shuffle(&mut thread_rng()); + } + let results: Vec = run_reduce(&v, N, D, "fast_argmax_f32"); + assert_eq!(results, correct_argmax::(v)); +} - let results = run_reduce(&v, out_length, "fast_sum_f32_strided"); - assert_eq!(approx(results, 4), vec![21.0]); +#[test] +fn reduce_sum1() { + reduce_sum_case::<9, 1>(); + reduce_sum_case::<6, 1>(); + reduce_sum_case::<10, 1>(); + reduce_sum_case::<64, 1>(); + reduce_sum_case::<128, 1>(); + reduce_sum_case::<256, 1>(); + reduce_sum_case::<512, 1>(); + reduce_sum_case::<1024, 1>(); + reduce_sum_case::<2048, 1>(); + reduce_sum_case::<4096, 1>(); } #[test] fn reduce_sum2() { - let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0]; - let out_length = 2; + reduce_sum_case::<6, 2>(); + reduce_sum_case::<10, 2>(); + reduce_sum_case::<64, 2>(); + reduce_sum_case::<128, 2>(); + reduce_sum_case::<256, 2>(); + reduce_sum_case::<512, 2>(); + reduce_sum_case::<1024, 2>(); + reduce_sum_case::<2048, 2>(); + reduce_sum_case::<4096, 2>(); +} + +#[test] +fn reduce_max() { + reduce_max_case::<6, 1>(); + reduce_max_case::<9, 1>(); + reduce_max_case::<10, 1>(); + reduce_max_case::<64, 1>(); + reduce_max_case::<128, 1>(); + reduce_max_case::<256, 1>(); + reduce_max_case::<512, 1>(); + reduce_max_case::<1024, 1>(); + reduce_max_case::<2048, 1>(); + reduce_max_case::<4096, 1>(); + + reduce_max_case::<6, 2>(); + reduce_max_case::<10, 2>(); + reduce_max_case::<64, 2>(); + reduce_max_case::<128, 2>(); + reduce_max_case::<256, 2>(); + reduce_max_case::<512, 2>(); + reduce_max_case::<1024, 2>(); + reduce_max_case::<2048, 2>(); + reduce_max_case::<4096, 2>(); + + reduce_max_case::<6, 3>(); + reduce_max_case::<10, 3>(); + reduce_max_case::<64, 3>(); + reduce_max_case::<128, 3>(); + reduce_max_case::<256, 3>(); + reduce_max_case::<512, 3>(); + reduce_max_case::<1024, 3>(); + reduce_max_case::<2048, 3>(); + reduce_max_case::<4096, 3>(); +} - let results = run_reduce(&v, out_length, "fast_sum_f32_strided"); - assert_eq!(approx(results, 4), vec![6.0, 15.0]); +#[test] +fn reduce_argmax() { + reduce_argmax_case::<6, 1>(); + reduce_argmax_case::<9, 1>(); + reduce_argmax_case::<10, 1>(); + reduce_argmax_case::<64, 1>(); + 
reduce_argmax_case::<128, 1>(); + reduce_argmax_case::<256, 1>(); + reduce_argmax_case::<512, 1>(); + reduce_argmax_case::<1024, 1>(); + reduce_argmax_case::<2048, 1>(); +} + +#[test] +fn reduce_argmax2() { + reduce_argmax_case::<6, 2>(); + reduce_argmax_case::<10, 2>(); + reduce_argmax_case::<64, 2>(); + reduce_argmax_case::<128, 2>(); + reduce_argmax_case::<256, 2>(); + reduce_argmax_case::<512, 2>(); + reduce_argmax_case::<1024, 2>(); + reduce_argmax_case::<2048, 2>(); + reduce_argmax_case::<4096, 2>(); } #[test] @@ -983,7 +1158,7 @@ fn softmax() { let results = run_softmax(&v, last_dim, "softmax_f16"); assert_eq!( approx_f16(results, 4), - vec![0.0043, 0.0116, 0.0316, 0.0858, 0.2332, 0.6338] + vec![0.0043, 0.0116, 0.0315, 0.0858, 0.2332, 0.6338] ); let v = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0] diff --git a/candle-metal-kernels/src/utils.metal b/candle-metal-kernels/src/utils.metal new file mode 100644 index 0000000000..8ee6b4ad76 --- /dev/null +++ b/candle-metal-kernels/src/utils.metal @@ -0,0 +1,47 @@ +#pragma once +#include +using namespace metal; + +METAL_FUNC uint nonzero(uint n) { + return n == 0 ? 1 : n; +} + +template +constexpr uint nonzero() { + return N == 0 ? 1 : N; +} + +template +constexpr ushort granularity() { + return nonzero::value>(); +} + +METAL_FUNC uint next_p2(uint x) { + return 1 << (32 - clz(x - 1)); +} + +METAL_FUNC uint prev_p2(uint x) { + return 1 << (31 - clz(x)); +} + +constant uint MAX_SHARED_MEM = 32767; + +template +METAL_FUNC uint max_shared_mem(uint n) { + return min(n, prev_p2(MAX_SHARED_MEM / sizeof(T))); +} + +METAL_FUNC uint get_strided_index( + uint idx, + constant const uint &num_dims, + constant const size_t *dims, + constant const size_t *strides +) { + uint strided_i = 0; + for (uint d = 0; d < num_dims; d++) { + uint dim_idx = num_dims - 1 - d; + strided_i += (idx % dims[dim_idx]) * strides[dim_idx]; + idx /= dims[dim_idx]; + } + return strided_i; +} diff --git a/candle-nn/benches/bench_main.rs b/candle-nn/benches/bench_main.rs index 4db1d35c0a..64d9b8b46e 100644 --- a/candle-nn/benches/bench_main.rs +++ b/candle-nn/benches/bench_main.rs @@ -1,4 +1,8 @@ mod benchmarks; use criterion::criterion_main; -criterion_main!(benchmarks::layer_norm::benches, benchmarks::conv::benches); +criterion_main!( + benchmarks::softmax::benches, + benchmarks::layer_norm::benches, + benchmarks::conv::benches +); diff --git a/candle-nn/benches/benchmarks/mod.rs b/candle-nn/benches/benchmarks/mod.rs index 30a6ab6a2b..a34d888439 100644 --- a/candle-nn/benches/benchmarks/mod.rs +++ b/candle-nn/benches/benchmarks/mod.rs @@ -1,5 +1,6 @@ pub(crate) mod conv; pub(crate) mod layer_norm; +pub(crate) mod softmax; use candle::{Device, Result}; diff --git a/candle-nn/benches/benchmarks/softmax.rs b/candle-nn/benches/benchmarks/softmax.rs new file mode 100644 index 0000000000..2a1ea2d547 --- /dev/null +++ b/candle-nn/benches/benchmarks/softmax.rs @@ -0,0 +1,49 @@ +use crate::benchmarks::{BenchDevice, BenchDeviceHandler}; +use candle::{DType, Device, Tensor}; +use candle_nn::ops::softmax_last_dim; +use criterion::Throughput; +use criterion::{black_box, criterion_group, Criterion}; +use std::time::Instant; + +fn run(input: &Tensor) { + let _ = softmax_last_dim(&input).unwrap(); +} + +const B: usize = 1; +const M: usize = 1024; +const K: usize = 1024; + +fn run_softmax_benchmark(c: &mut Criterion, device: &Device, dtype: DType, name: &str) { + let elements = B * M * K; + + let input = Tensor::rand(-1000.0f32, 1000.0f32, (B, M, K), &device) + .unwrap() + .to_dtype(dtype) + .unwrap(); + 
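+    // Throughput is reported in bytes processed per iteration (element count * dtype size).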
+ let flops = elements * dtype.size_in_bytes(); + let mut group = c.benchmark_group(device.bench_name(name)); + group.throughput(Throughput::Bytes(flops as u64)); + group.bench_function("iter", move |b| { + b.iter_custom(|iters| { + let start = Instant::now(); + for _i in 0..iters { + run(black_box(&input)); + } + device.sync().unwrap(); + start.elapsed() + }) + }); + group.finish(); +} + +fn criterion_benchmark(c: &mut Criterion) { + let device = BenchDeviceHandler::new().unwrap(); + for d in device.devices { + run_softmax_benchmark(c, &d, DType::F32, "softmax_f32"); + run_softmax_benchmark(c, &d, DType::BF16, "softmax_bf16"); + run_softmax_benchmark(c, &d, DType::F16, "softmax_f16"); + } +} + +criterion_group!(benches, criterion_benchmark);