From 90c74e199cb241377ceef4ba1e7d1152d15d414c Mon Sep 17 00:00:00 2001
From: Ivar Flakstad <69173633+ivarflakstad@users.noreply.github.com>
Date: Fri, 29 Dec 2023 11:38:13 +0100
Subject: [PATCH 1/9] Add metal fill kernel

---
 candle-metal-kernels/src/fill.metal | 31 +++++++++++++++++++++++++++++
 1 file changed, 31 insertions(+)
 create mode 100644 candle-metal-kernels/src/fill.metal

diff --git a/candle-metal-kernels/src/fill.metal b/candle-metal-kernels/src/fill.metal
new file mode 100644
index 0000000000..9d09980e79
--- /dev/null
+++ b/candle-metal-kernels/src/fill.metal
@@ -0,0 +1,31 @@
+#include <metal_stdlib>
+using namespace metal;
+
+template <typename T>
+void fill(
+    device T *buffer [[buffer(0)]],
+    constant T &value,
+    constant size_t &numel,
+    uint gid [[thread_position_in_grid]]
+) {
+    if (gid >= numel) return;
+    buffer[gid] = value;
+}
+
+#define FILL_OP(T, FN_NAME) \
+kernel void FN_NAME( \
+    device T *buffer [[buffer(0)]], \
+    constant T &value, \
+    constant size_t &numel, \
+    uint gid [[thread_position_in_grid]] \
+) { fill(buffer, value, numel, gid); } \
+
+FILL_OP(uint8_t, fill_u8)
+FILL_OP(uint32_t, fill_u32)
+FILL_OP(int64_t, fill_i64)
+FILL_OP(half, fill_f16)
+FILL_OP(float, fill_f32)
+
+#if __METAL_VERSION__ >= 310
+FILL_OP(bfloat, fill_bf16)
+#endif

From fd9bf3bcdd8294d2e4d8b289f79723543dde3ebf Mon Sep 17 00:00:00 2001
From: Ivar Flakstad <69173633+ivarflakstad@users.noreply.github.com>
Date: Fri, 29 Dec 2023 11:39:49 +0100
Subject: [PATCH 2/9] remove stray #

---
 candle-metal-kernels/src/unary.metal | 1 -
 1 file changed, 1 deletion(-)

diff --git a/candle-metal-kernels/src/unary.metal b/candle-metal-kernels/src/unary.metal
index 04fa37a98d..2a88598c89 100644
--- a/candle-metal-kernels/src/unary.metal
+++ b/candle-metal-kernels/src/unary.metal
@@ -1,6 +1,5 @@
 #include <metal_stdlib>
 #include <metal_math>
-#
 using namespace metal;
 
 METAL_FUNC uint get_strided_index(

From 0a29d2e9b85c3cb7aff4fff47ed94146dcf8cdb1 Mon Sep 17 00:00:00 2001
From: Ivar Flakstad <69173633+ivarflakstad@users.noreply.github.com>
Date: Fri, 29 Dec 2023 12:27:12 +0100
Subject: [PATCH 3/9] Add fill kernel handler

---
 candle-metal-kernels/Cargo.toml   |  2 +-
 candle-metal-kernels/src/lib.rs   | 37 ++++++++++++++++++++++++++++-
 candle-metal-kernels/src/tests.rs | 47 ++++++++++++++++++++++++++++++-
 3 files changed, 83 insertions(+), 3 deletions(-)

diff --git a/candle-metal-kernels/Cargo.toml b/candle-metal-kernels/Cargo.toml
index 441d2e88b9..ba09ffcbd1 100644
--- a/candle-metal-kernels/Cargo.toml
+++ b/candle-metal-kernels/Cargo.toml
@@ -14,7 +14,7 @@ metal = { version = "0.27.0", features = ["mps"]}
 once_cell = "1.18.0"
 thiserror = "1"
 tracing = "0.1.37"
+half = { version = "2.3.1", features = ["num-traits", "use-intrinsics", "rand_distr"] }
 
 [dev-dependencies]
-half = { version = "2.3.1", features = ["num-traits", "use-intrinsics", "rand_distr"] }
 rand = "0.8.5"

diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs
index dd97a86d69..a27306324f 100644
--- a/candle-metal-kernels/src/lib.rs
+++ b/candle-metal-kernels/src/lib.rs
@@ -1,3 +1,4 @@
+use half::{bf16, f16};
 use metal::{
     Buffer, CommandBufferRef, CompileOptions, ComputeCommandEncoderRef, ComputePipelineState,
     Device, Function, FunctionConstantValues, Library, MTLDataType, MTLSize, NSUInteger,
@@ -12,6 +13,7 @@ const UNARY: &str = include_str!("unary.metal");
 const BINARY: &str = include_str!("binary.metal");
 const TERNARY: &str = include_str!("ternary.metal");
 const CAST: &str = include_str!("cast.metal");
+const FILL: &str = include_str!("fill.metal");
 const REDUCE: &str =
include_str!("reduce.metal"); const CONV: &str = include_str!("conv.metal"); const MFA: &[u8] = include_bytes!("libMetalFlashAttention.metallib"); @@ -45,7 +47,7 @@ fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, /// Helper functions to create the various objects on the compute command encoder /// on a single line. /// Prevents getting wrong some arguments number and mixing length and size in bytes. -trait EncoderParam { +pub trait EncoderParam { fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self); } macro_rules! primitive { @@ -62,7 +64,11 @@ macro_rules! primitive { }; } primitive!(usize); +primitive!(u8); primitive!(u32); +primitive!(i64); +primitive!(f16); +primitive!(bf16); primitive!(f32); impl EncoderParam for &[T] { @@ -117,6 +123,7 @@ pub enum Source { Reduce, Mfa, Conv, + Fill, } macro_rules! ops{ @@ -227,6 +234,7 @@ impl Kernels { Source::Indexing => INDEXING, Source::Cast => CAST, Source::Reduce => REDUCE, + Source::Fill => FILL, Source::Conv => CONV, Source::Mfa => panic!("Invalid lib"), } @@ -1562,9 +1570,36 @@ pub fn call_upsample_nearest_2d( Ok(()) } +#[inline] fn divide(m: usize, b: usize) -> NSUInteger { ((m + b - 1) / b) as NSUInteger } +pub fn call_fill( + device: &Device, + command_buffer: &CommandBufferRef, + kernels: &Kernels, + kernel_name: &'static str, + elem_count: usize, + buffer: &Buffer, + value: D, +) -> Result<(), MetalKernelError> { + let pipeline = kernels.load_pipeline(device, Source::Fill, kernel_name)?; + let encoder = command_buffer.new_compute_command_encoder(); + encoder.wait_for_fence(&kernels.fence); + encoder.set_compute_pipeline_state(&pipeline); + encoder.set_threadgroup_memory_length(0, elem_count as NSUInteger); + + set_params!(encoder, (buffer, value, elem_count)); + + let (thread_group_count, thread_group_size) = linear_split(&pipeline, elem_count); + encoder.dispatch_thread_groups(thread_group_count, thread_group_size); + encoder.use_resource(buffer, metal::MTLResourceUsage::Write); + encoder.update_fence(&kernels.fence); + encoder.end_encoding(); + + Ok(()) +} + #[cfg(test)] mod tests; diff --git a/candle-metal-kernels/src/tests.rs b/candle-metal-kernels/src/tests.rs index c955abca21..c1c7b8ab03 100644 --- a/candle-metal-kernels/src/tests.rs +++ b/candle-metal-kernels/src/tests.rs @@ -11,7 +11,7 @@ fn read_to_vec(buffer: &Buffer, n: usize) -> Vec { fn new_buffer(device: &Device, data: &[T]) -> Buffer { let options = MTLResourceOptions::StorageModeManaged; - let ptr = data.as_ptr() as *const core::ffi::c_void; + let ptr = data.as_ptr() as *const c_void; let size = (data.len() * std::mem::size_of::()) as u64; device.new_buffer_with_data(ptr, size, options) } @@ -806,3 +806,48 @@ fn gemm() { vec![56.0, 59.0, 62.0, 65.0, 200.0, 212.0, 224.0, 236.0] ); } + +fn run_fill( + elem_count: usize, + value: T, + kernel_name: &'static str, +) -> Vec { + let device = device(); + let fence = device.new_fence(); + let kernels = Kernels::new(fence); + let command_queue = device.new_command_queue(); + let command_buffer = command_queue.new_command_buffer(); + let buffer = new_buffer(&device, &vec![0.0f32; elem_count]); + call_fill( + &device, + command_buffer, + &kernels, + kernel_name, + elem_count, + &buffer, + value, + ) + .unwrap(); + command_buffer.commit(); + command_buffer.wait_until_completed(); + + read_to_vec(&buffer, elem_count) +} + +#[test] +fn fill() { + fn assert_fill( + value: T, + name: &'static str, + ) { + for i in 0..4 { + assert_eq!(run_fill(8 ^ i, value, name), vec![value; 8 ^ i]); + } + } + 
assert_fill(123u8, "fill_u8"); + assert_fill(456u32, "fill_u32"); + assert_fill(789i64, "fill_i64"); + assert_fill(f16::from_f32(1.23), "fill_f16"); + assert_fill(bf16::from_f32(4.56), "fill_bf16"); + assert_fill(7.89f32, "fill_f32"); +} From 7fc26764b6fb95429c9130c6122a2f4c2038d1ba Mon Sep 17 00:00:00 2001 From: Ivar Flakstad <69173633+ivarflakstad@users.noreply.github.com> Date: Fri, 29 Dec 2023 16:02:29 +0100 Subject: [PATCH 4/9] Implement generic fill. u8 uses speedy blit encoder --- candle-core/src/metal_backend.rs | 58 +++++++++++++++++------- candle-metal-kernels/Cargo.toml | 1 + candle-metal-kernels/src/lib.rs | 75 ++++++++++++++++++++++++++++++- candle-metal-kernels/src/tests.rs | 35 +++++++-------- 4 files changed, 134 insertions(+), 35 deletions(-) diff --git a/candle-core/src/metal_backend.rs b/candle-core/src/metal_backend.rs index 6d8afab192..f2b55d4e92 100644 --- a/candle-core/src/metal_backend.rs +++ b/candle-core/src/metal_backend.rs @@ -3,7 +3,8 @@ use crate::conv::{ParamsConv1D, ParamsConv2D, ParamsConvTranspose1D, ParamsConvT use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT}; use crate::{CpuStorage, DType, Layout, Result, Shape}; use candle_metal_kernels; -use candle_metal_kernels::Kernels; +use candle_metal_kernels::{CallFill, Fill, Kernels}; +use half::{bf16, f16}; use metal; use metal::{Buffer, CommandBuffer, CommandQueue, MTLResourceOptions, NSUInteger}; use std::collections::HashMap; @@ -1403,25 +1404,52 @@ impl BackendDevice for MetalDevice { let buffer = self.new_buffer(shape.elem_count(), dtype, "zeros")?; let command_buffer = self.command_buffer()?; command_buffer.set_label("zeros"); - let blit = command_buffer.new_blit_command_encoder(); - blit.wait_for_fence(&self.fence); - blit.fill_buffer( + + // This assumes the specific zero type DType is equal to 0x00u8 + // (which is true for all current types) + Fill::call_fill( + &self.device, + &command_buffer, + &self.kernels, + shape.elem_count(), &buffer, - metal::NSRange { - location: 0, - length: buffer.length(), - }, - 0, - ); - blit.update_fence(&self.fence); - blit.end_encoding(); + 0u8, + ) + .map_err(MetalError::from)?; + Ok(MetalStorage::new(buffer, self.clone(), dtype)) } fn ones_impl(&self, shape: &Shape, dtype: DType) -> Result { - // TODO Is there a faster way ? - let cpu_storage = crate::cpu_backend::CpuDevice.ones_impl(shape, dtype)?; - self.storage_from_cpu_storage(&cpu_storage) + let buffer = self.new_buffer(shape.elem_count(), dtype, "zeros")?; + let command_buffer = self.command_buffer()?; + command_buffer.set_label("ones"); + + macro_rules! fill { + ($value:expr) => { + Fill::call_fill( + &self.device, + &command_buffer, + &self.kernels, + shape.elem_count(), + &buffer, + $value, + ) + .map_err(MetalError::from)? 
+ }; + } + match dtype { + DType::U8 => fill!(1u8), + DType::U32 => fill!(1u32), + DType::I64 => fill!(1i64), + DType::BF16 => fill!(bf16::ONE), + DType::F16 => fill!(f16::ONE), + DType::F32 => fill!(1f32), + DType::F64 => { + return Err(MetalError::Message(format!("metal doesn't support double")).into()) + } + } + Ok(MetalStorage::new(buffer, self.clone(), dtype)) } fn storage_from_cpu_storage(&self, storage: &CpuStorage) -> Result { diff --git a/candle-metal-kernels/Cargo.toml b/candle-metal-kernels/Cargo.toml index ba09ffcbd1..6c64a8e5a5 100644 --- a/candle-metal-kernels/Cargo.toml +++ b/candle-metal-kernels/Cargo.toml @@ -15,6 +15,7 @@ once_cell = "1.18.0" thiserror = "1" tracing = "0.1.37" half = { version = "2.3.1", features = ["num-traits", "use-intrinsics", "rand_distr"] } +num-traits = "0.2.17" [dev-dependencies] rand = "0.8.5" diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs index a27306324f..e0985d9497 100644 --- a/candle-metal-kernels/src/lib.rs +++ b/candle-metal-kernels/src/lib.rs @@ -5,6 +5,7 @@ use metal::{ }; use std::collections::HashMap; use std::ffi::c_void; +use std::marker::PhantomData; use std::sync::RwLock; const AFFINE: &str = include_str!("affine.metal"); @@ -180,6 +181,8 @@ pub mod binary { #[derive(thiserror::Error, Debug)] pub enum MetalKernelError { + #[error("Invalid usage of kernel: {0}")] + InvalidUsage(String), #[error("Could not lock kernel map: {0}")] LockError(String), #[error("Error while loading library: {0}")] @@ -1575,7 +1578,77 @@ fn divide(m: usize, b: usize) -> NSUInteger { ((m + b - 1) / b) as NSUInteger } -pub fn call_fill( +pub struct Fill { + _marker: PhantomData, +} + +pub trait CallFill { + const KERNEL_NAME: &'static str; + + fn call_fill( + device: &Device, + command_buffer: &CommandBufferRef, + kernels: &Kernels, + elem_count: usize, + buffer: &Buffer, + value: T, + ) -> Result<(), MetalKernelError>; +} + +macro_rules ! 
impl_call_fill { + ($($t:ty),*) => { + $( + impl CallFill<$t> for Fill<$t> { + const KERNEL_NAME: &'static str = concat!("fill_", stringify!($t)); + + fn call_fill(device: &Device, command_buffer: &CommandBufferRef, kernels: &Kernels, elem_count: usize, buffer: &Buffer, value: $t) -> Result<(), MetalKernelError> { + _call_fill(device, command_buffer, kernels, Self::KERNEL_NAME, elem_count, buffer, value) + } + } + )* + }; +} +impl_call_fill!(u32, i64, f16, bf16, f32); + +impl CallFill for Fill { + const KERNEL_NAME: &'static str = ""; + + fn call_fill( + _: &Device, + command_buffer: &CommandBufferRef, + kernels: &Kernels, + elem_count: usize, + buffer: &Buffer, + value: u8, + ) -> Result<(), MetalKernelError> { + _call_blit_fill(command_buffer, kernels, elem_count, buffer, value) + } +} + +fn _call_blit_fill( + command_buffer: &CommandBufferRef, + kernels: &Kernels, + elem_count: usize, + buffer: &Buffer, + value: u8, +) -> Result<(), MetalKernelError> { + let blit = command_buffer.new_blit_command_encoder(); + blit.wait_for_fence(&kernels.fence); + blit.fill_buffer( + &buffer, + metal::NSRange { + location: 0, + length: elem_count as NSUInteger, + }, + value, + ); + blit.update_fence(&kernels.fence); + blit.end_encoding(); + + Ok(()) +} + +fn _call_fill( device: &Device, command_buffer: &CommandBufferRef, kernels: &Kernels, diff --git a/candle-metal-kernels/src/tests.rs b/candle-metal-kernels/src/tests.rs index c1c7b8ab03..a4fb726fd1 100644 --- a/candle-metal-kernels/src/tests.rs +++ b/candle-metal-kernels/src/tests.rs @@ -590,7 +590,6 @@ fn softmax() { } let results = run_softmax(&v, last_dim, "softmax_f32"); let results = approx(results, 4); - println!("{results:?}"); assert_eq!( results.iter().map(|&s| s.round() as usize).sum::(), n @@ -807,22 +806,20 @@ fn gemm() { ); } -fn run_fill( - elem_count: usize, - value: T, - kernel_name: &'static str, -) -> Vec { +fn run_fill(elem_count: usize, value: T) -> Vec +where + Fill: CallFill, +{ let device = device(); let fence = device.new_fence(); let kernels = Kernels::new(fence); let command_queue = device.new_command_queue(); let command_buffer = command_queue.new_command_buffer(); let buffer = new_buffer(&device, &vec![0.0f32; elem_count]); - call_fill( + Fill::::call_fill( &device, command_buffer, &kernels, - kernel_name, elem_count, &buffer, value, @@ -836,18 +833,18 @@ fn run_fill( #[test] fn fill() { - fn assert_fill( - value: T, - name: &'static str, - ) { + fn assert_fill(value: T) + where + Fill: CallFill, + { for i in 0..4 { - assert_eq!(run_fill(8 ^ i, value, name), vec![value; 8 ^ i]); + assert_eq!(run_fill(8 ^ i, value), vec![value; 8 ^ i]); } } - assert_fill(123u8, "fill_u8"); - assert_fill(456u32, "fill_u32"); - assert_fill(789i64, "fill_i64"); - assert_fill(f16::from_f32(1.23), "fill_f16"); - assert_fill(bf16::from_f32(4.56), "fill_bf16"); - assert_fill(7.89f32, "fill_f32"); + assert_fill(123u8); + assert_fill(456u32); + assert_fill(789i64); + assert_fill(f16::from_f32(1.23)); + assert_fill(bf16::from_f32(4.56)); + assert_fill(7.89f32); } From 6eb44d1bcef00052e2a56424af82bd14d65f2df8 Mon Sep 17 00:00:00 2001 From: Ivar Flakstad <69173633+ivarflakstad@users.noreply.github.com> Date: Mon, 1 Jan 2024 20:22:44 +0100 Subject: [PATCH 5/9] Added fill bench --- candle-core/Cargo.toml | 4 ++ candle-core/benches/fill.rs | 57 ++++++++++++++++ candle-core/src/metal_backend.rs | 10 +-- candle-metal-kernels/Cargo.toml | 5 ++ candle-metal-kernels/src/lib.rs | 106 ++++++++++++------------------ candle-metal-kernels/src/tests.rs | 8 +-- 6 files 
changed, 118 insertions(+), 72 deletions(-) create mode 100644 candle-core/benches/fill.rs diff --git a/candle-core/Cargo.toml b/candle-core/Cargo.toml index 91655f5782..6bd1258963 100644 --- a/candle-core/Cargo.toml +++ b/candle-core/Cargo.toml @@ -49,3 +49,7 @@ metal = ["dep:metal", "dep:candle-metal-kernels"] name = "matmul" harness = false +[[bench]] +name = "fill" +harness = false + diff --git a/candle-core/benches/fill.rs b/candle-core/benches/fill.rs new file mode 100644 index 0000000000..9bcb47751c --- /dev/null +++ b/candle-core/benches/fill.rs @@ -0,0 +1,57 @@ +use candle_core::{DType, Device, Tensor}; +use criterion::{black_box, criterion_group, criterion_main, Criterion, Throughput}; +use std::time::Instant; + +fn run(shape: (usize, usize, usize), dtype: DType, device: &Device) { + Tensor::ones(shape, dtype, device).unwrap(); +} + +fn criterion_benchmark(c: &mut Criterion) { + let b = 1; + let rows = 4096; + let columns = 4096; + + let flops = b * rows * columns; + + let device1 = Device::new_metal(0).unwrap(); + let device2 = device1.clone(); + + let mut group = c.benchmark_group("fill_metal_u8"); + group.throughput(Throughput::Bytes(flops as u64)); + group.bench_function("iter", move |bencher| { + bencher.iter_custom(|iters| { + let start = Instant::now(); + for _i in 0..iters { + run(black_box((b, rows, columns)), black_box(DType::U8), black_box(&device1)); + } + if let Device::Metal(device) = &device1 { + device.wait_until_completed().unwrap(); + } else { + panic!("Expected metal device"); + } + start.elapsed() + }) + }); + group.finish(); + + let mut group = c.benchmark_group("fill_metal_f32"); + group.throughput(Throughput::Bytes((flops * DType::F32.size_in_bytes()) as u64)); + group.bench_function("iter", move |bencher| { + bencher.iter_custom(|iters| { + let start = Instant::now(); + for _i in 0..iters { + run(black_box((b, rows, columns)), black_box(DType::F32), black_box(&device2)); + } + if let Device::Metal(device) = &device2 { + device.wait_until_completed().unwrap(); + } else { + panic!("Expected metal device"); + } + start.elapsed() + }) + }); + group.finish(); +} + +criterion_group!(benches, criterion_benchmark); +criterion_main!(benches); diff --git a/candle-core/src/metal_backend.rs b/candle-core/src/metal_backend.rs index f2b55d4e92..21eb13365a 100644 --- a/candle-core/src/metal_backend.rs +++ b/candle-core/src/metal_backend.rs @@ -3,7 +3,7 @@ use crate::conv::{ParamsConv1D, ParamsConv2D, ParamsConvTranspose1D, ParamsConvT use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT}; use crate::{CpuStorage, DType, Layout, Result, Shape}; use candle_metal_kernels; -use candle_metal_kernels::{CallFill, Fill, Kernels}; +use candle_metal_kernels::{FillOp, Unary, Kernels}; use half::{bf16, f16}; use metal; use metal::{Buffer, CommandBuffer, CommandQueue, MTLResourceOptions, NSUInteger}; @@ -1405,9 +1405,9 @@ impl BackendDevice for MetalDevice { let command_buffer = self.command_buffer()?; command_buffer.set_label("zeros"); - // This assumes the specific zero type DType is equal to 0x00u8 + // This assumes the zero value of this DType is equal to 0x00u8 // (which is true for all current types) - Fill::call_fill( + Unary::fill( &self.device, &command_buffer, &self.kernels, @@ -1421,13 +1421,13 @@ impl BackendDevice for MetalDevice { } fn ones_impl(&self, shape: &Shape, dtype: DType) -> Result { - let buffer = self.new_buffer(shape.elem_count(), dtype, "zeros")?; + let buffer = self.new_buffer(shape.elem_count(), dtype, "ones")?; let command_buffer = self.command_buffer()?; 
command_buffer.set_label("ones"); macro_rules! fill { ($value:expr) => { - Fill::call_fill( + Unary::fill( &self.device, &command_buffer, &self.kernels, diff --git a/candle-metal-kernels/Cargo.toml b/candle-metal-kernels/Cargo.toml index 6c64a8e5a5..25446d290a 100644 --- a/candle-metal-kernels/Cargo.toml +++ b/candle-metal-kernels/Cargo.toml @@ -19,3 +19,8 @@ num-traits = "0.2.17" [dev-dependencies] rand = "0.8.5" +criterion = "0.5.1" + +[[bench]] +name = "fill" +harness = false diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs index e0985d9497..f5b0653bf4 100644 --- a/candle-metal-kernels/src/lib.rs +++ b/candle-metal-kernels/src/lib.rs @@ -1573,19 +1573,19 @@ pub fn call_upsample_nearest_2d( Ok(()) } -#[inline] +#[inline(always)] fn divide(m: usize, b: usize) -> NSUInteger { ((m + b - 1) / b) as NSUInteger } -pub struct Fill { +pub struct Unary { _marker: PhantomData, } -pub trait CallFill { - const KERNEL_NAME: &'static str; +pub trait FillOp { + const FILL_KERNEL: &'static str; - fn call_fill( + fn fill( device: &Device, command_buffer: &CommandBufferRef, kernels: &Kernels, @@ -1598,11 +1598,26 @@ pub trait CallFill { macro_rules ! impl_call_fill { ($($t:ty),*) => { $( - impl CallFill<$t> for Fill<$t> { - const KERNEL_NAME: &'static str = concat!("fill_", stringify!($t)); - - fn call_fill(device: &Device, command_buffer: &CommandBufferRef, kernels: &Kernels, elem_count: usize, buffer: &Buffer, value: $t) -> Result<(), MetalKernelError> { - _call_fill(device, command_buffer, kernels, Self::KERNEL_NAME, elem_count, buffer, value) + impl FillOp<$t> for Unary<$t> { + const FILL_KERNEL: &'static str = concat!("fill_", stringify!($t)); + + #[inline(always)] + fn fill(device: &Device, command_buffer: &CommandBufferRef, kernels: &Kernels, elem_count: usize, buffer: &Buffer, value: $t) -> Result<(), MetalKernelError> { + let pipeline = kernels.load_pipeline(device, Source::Fill, Self::FILL_KERNEL)?; + let encoder = command_buffer.new_compute_command_encoder(); + encoder.wait_for_fence(&kernels.fence); + encoder.set_compute_pipeline_state(&pipeline); + encoder.set_threadgroup_memory_length(0, elem_count as NSUInteger); + + set_params!(encoder, (buffer, value, elem_count)); + + let (thread_group_count, thread_group_size) = linear_split(&pipeline, elem_count); + encoder.dispatch_thread_groups(thread_group_count, thread_group_size); + encoder.use_resource(buffer, metal::MTLResourceUsage::Write); + encoder.update_fence(&kernels.fence); + encoder.end_encoding(); + + Ok(()) } } )* @@ -1610,10 +1625,11 @@ macro_rules ! 
impl_call_fill { } impl_call_fill!(u32, i64, f16, bf16, f32); -impl CallFill for Fill { - const KERNEL_NAME: &'static str = ""; +impl FillOp for Unary { + const FILL_KERNEL: &'static str = ""; - fn call_fill( + #[inline(always)] + fn fill( _: &Device, command_buffer: &CommandBufferRef, kernels: &Kernels, @@ -1621,57 +1637,21 @@ impl CallFill for Fill { buffer: &Buffer, value: u8, ) -> Result<(), MetalKernelError> { - _call_blit_fill(command_buffer, kernels, elem_count, buffer, value) - } -} - -fn _call_blit_fill( - command_buffer: &CommandBufferRef, - kernels: &Kernels, - elem_count: usize, - buffer: &Buffer, - value: u8, -) -> Result<(), MetalKernelError> { - let blit = command_buffer.new_blit_command_encoder(); - blit.wait_for_fence(&kernels.fence); - blit.fill_buffer( - &buffer, - metal::NSRange { - location: 0, - length: elem_count as NSUInteger, - }, - value, - ); - blit.update_fence(&kernels.fence); - blit.end_encoding(); - - Ok(()) -} - -fn _call_fill( - device: &Device, - command_buffer: &CommandBufferRef, - kernels: &Kernels, - kernel_name: &'static str, - elem_count: usize, - buffer: &Buffer, - value: D, -) -> Result<(), MetalKernelError> { - let pipeline = kernels.load_pipeline(device, Source::Fill, kernel_name)?; - let encoder = command_buffer.new_compute_command_encoder(); - encoder.wait_for_fence(&kernels.fence); - encoder.set_compute_pipeline_state(&pipeline); - encoder.set_threadgroup_memory_length(0, elem_count as NSUInteger); - - set_params!(encoder, (buffer, value, elem_count)); - - let (thread_group_count, thread_group_size) = linear_split(&pipeline, elem_count); - encoder.dispatch_thread_groups(thread_group_count, thread_group_size); - encoder.use_resource(buffer, metal::MTLResourceUsage::Write); - encoder.update_fence(&kernels.fence); - encoder.end_encoding(); + let blit = command_buffer.new_blit_command_encoder(); + blit.wait_for_fence(&kernels.fence); + blit.fill_buffer( + &buffer, + metal::NSRange { + location: 0, + length: elem_count as NSUInteger, + }, + value, + ); + blit.update_fence(&kernels.fence); + blit.end_encoding(); - Ok(()) + Ok(()) + } } #[cfg(test)] diff --git a/candle-metal-kernels/src/tests.rs b/candle-metal-kernels/src/tests.rs index a4fb726fd1..b7bff740f2 100644 --- a/candle-metal-kernels/src/tests.rs +++ b/candle-metal-kernels/src/tests.rs @@ -808,7 +808,7 @@ fn gemm() { fn run_fill(elem_count: usize, value: T) -> Vec where - Fill: CallFill, + Unary: FillOp, { let device = device(); let fence = device.new_fence(); @@ -816,7 +816,7 @@ where let command_queue = device.new_command_queue(); let command_buffer = command_queue.new_command_buffer(); let buffer = new_buffer(&device, &vec![0.0f32; elem_count]); - Fill::::call_fill( + Unary::::fill( &device, command_buffer, &kernels, @@ -835,7 +835,7 @@ where fn fill() { fn assert_fill(value: T) where - Fill: CallFill, + Unary: FillOp, { for i in 0..4 { assert_eq!(run_fill(8 ^ i, value), vec![value; 8 ^ i]); @@ -847,4 +847,4 @@ fn fill() { assert_fill(f16::from_f32(1.23)); assert_fill(bf16::from_f32(4.56)); assert_fill(7.89f32); -} +} \ No newline at end of file From e8e24f1284decb5b56a1e3f3ff41e49860d01244 Mon Sep 17 00:00:00 2001 From: Ivar Flakstad <69173633+ivarflakstad@users.noreply.github.com> Date: Mon, 1 Jan 2024 20:37:56 +0100 Subject: [PATCH 6/9] Follow crate conventions --- candle-core/benches/fill.rs | 16 ++++- candle-core/src/metal_backend.rs | 20 ++++-- candle-metal-kernels/Cargo.toml | 4 -- candle-metal-kernels/src/lib.rs | 110 +++++++++++++----------------- candle-metal-kernels/src/tests.rs 
| 2 +- 5 files changed, 75 insertions(+), 77 deletions(-) diff --git a/candle-core/benches/fill.rs b/candle-core/benches/fill.rs index 9bcb47751c..9bd0aa728d 100644 --- a/candle-core/benches/fill.rs +++ b/candle-core/benches/fill.rs @@ -22,7 +22,11 @@ fn criterion_benchmark(c: &mut Criterion) { bencher.iter_custom(|iters| { let start = Instant::now(); for _i in 0..iters { - run(black_box((b, rows, columns)), black_box(DType::U8), black_box(&device1)); + run( + black_box((b, rows, columns)), + black_box(DType::U8), + black_box(&device1), + ); } if let Device::Metal(device) = &device1 { device.wait_until_completed().unwrap(); @@ -35,12 +39,18 @@ fn criterion_benchmark(c: &mut Criterion) { group.finish(); let mut group = c.benchmark_group("fill_metal_f32"); - group.throughput(Throughput::Bytes((flops * DType::F32.size_in_bytes()) as u64)); + group.throughput(Throughput::Bytes( + (flops * DType::F32.size_in_bytes()) as u64, + )); group.bench_function("iter", move |bencher| { bencher.iter_custom(|iters| { let start = Instant::now(); for _i in 0..iters { - run(black_box((b, rows, columns)), black_box(DType::F32), black_box(&device2)); + run( + black_box((b, rows, columns)), + black_box(DType::F32), + black_box(&device2), + ); } if let Device::Metal(device) = &device2 { device.wait_until_completed().unwrap(); diff --git a/candle-core/src/metal_backend.rs b/candle-core/src/metal_backend.rs index 21eb13365a..3f6060ce89 100644 --- a/candle-core/src/metal_backend.rs +++ b/candle-core/src/metal_backend.rs @@ -3,7 +3,7 @@ use crate::conv::{ParamsConv1D, ParamsConv2D, ParamsConvTranspose1D, ParamsConvT use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT}; use crate::{CpuStorage, DType, Layout, Result, Shape}; use candle_metal_kernels; -use candle_metal_kernels::{FillOp, Unary, Kernels}; +use candle_metal_kernels::Kernels; use half::{bf16, f16}; use metal; use metal::{Buffer, CommandBuffer, CommandQueue, MTLResourceOptions, NSUInteger}; @@ -1405,15 +1405,14 @@ impl BackendDevice for MetalDevice { let command_buffer = self.command_buffer()?; command_buffer.set_label("zeros"); - // This assumes the zero value of this DType is equal to 0x00u8 + // This kernel assumes the zero value of this DType is equal to 0x00u8 // (which is true for all current types) - Unary::fill( - &self.device, + candle_metal_kernels::call_fill_u8( &command_buffer, &self.kernels, shape.elem_count(), &buffer, - 0u8, + 0, ) .map_err(MetalError::from)?; @@ -1427,7 +1426,7 @@ impl BackendDevice for MetalDevice { macro_rules! 
fill { ($value:expr) => { - Unary::fill( + candle_metal_kernels::call_fill( &self.device, &command_buffer, &self.kernels, @@ -1439,7 +1438,14 @@ impl BackendDevice for MetalDevice { }; } match dtype { - DType::U8 => fill!(1u8), + DType::U8 => candle_metal_kernels::call_fill_u8( + &command_buffer, + &self.kernels, + shape.elem_count(), + &buffer, + 1u8, + ) + .map_err(MetalError::from)?, DType::U32 => fill!(1u32), DType::I64 => fill!(1i64), DType::BF16 => fill!(bf16::ONE), diff --git a/candle-metal-kernels/Cargo.toml b/candle-metal-kernels/Cargo.toml index 25446d290a..162adbd797 100644 --- a/candle-metal-kernels/Cargo.toml +++ b/candle-metal-kernels/Cargo.toml @@ -20,7 +20,3 @@ num-traits = "0.2.17" [dev-dependencies] rand = "0.8.5" criterion = "0.5.1" - -[[bench]] -name = "fill" -harness = false diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs index f5b0653bf4..5db4c2cb17 100644 --- a/candle-metal-kernels/src/lib.rs +++ b/candle-metal-kernels/src/lib.rs @@ -5,7 +5,6 @@ use metal::{ }; use std::collections::HashMap; use std::ffi::c_void; -use std::marker::PhantomData; use std::sync::RwLock; const AFFINE: &str = include_str!("affine.metal"); @@ -1578,81 +1577,68 @@ fn divide(m: usize, b: usize) -> NSUInteger { ((m + b - 1) / b) as NSUInteger } -pub struct Unary { - _marker: PhantomData, +pub fn call_fill( + device: &Device, + command_buffer: &CommandBufferRef, + kernels: &Kernels, + elem_count: usize, + buffer: &Buffer, + value: T, +) -> Result<(), MetalKernelError> { + let pipeline = kernels.load_pipeline(device, Source::Fill, T::FILL_KERNEL)?; + let encoder = command_buffer.new_compute_command_encoder(); + encoder.wait_for_fence(&kernels.fence); + encoder.set_compute_pipeline_state(&pipeline); + encoder.set_threadgroup_memory_length(0, elem_count as NSUInteger); + + set_params!(encoder, (buffer, value, elem_count)); + + let (thread_group_count, thread_group_size) = linear_split(&pipeline, elem_count); + encoder.dispatch_thread_groups(thread_group_count, thread_group_size); + encoder.use_resource(buffer, metal::MTLResourceUsage::Write); + encoder.update_fence(&kernels.fence); + encoder.end_encoding(); + + Ok(()) } -pub trait FillOp { - const FILL_KERNEL: &'static str; +pub fn call_fill_u8( + command_buffer: &CommandBufferRef, + kernels: &Kernels, + elem_count: usize, + buffer: &Buffer, + value: u8, +) -> Result<(), MetalKernelError> { + let blit = command_buffer.new_blit_command_encoder(); + blit.wait_for_fence(&kernels.fence); + blit.fill_buffer( + buffer, + metal::NSRange { + location: 0, + length: elem_count as NSUInteger, + }, + value, + ); + blit.update_fence(&kernels.fence); + blit.end_encoding(); - fn fill( - device: &Device, - command_buffer: &CommandBufferRef, - kernels: &Kernels, - elem_count: usize, - buffer: &Buffer, - value: T, - ) -> Result<(), MetalKernelError>; + Ok(()) +} + +pub trait FillOp: EncoderParam { + const FILL_KERNEL: &'static str; } macro_rules ! 
impl_call_fill { ($($t:ty),*) => { $( - impl FillOp<$t> for Unary<$t> { + impl FillOp for $t { const FILL_KERNEL: &'static str = concat!("fill_", stringify!($t)); - - #[inline(always)] - fn fill(device: &Device, command_buffer: &CommandBufferRef, kernels: &Kernels, elem_count: usize, buffer: &Buffer, value: $t) -> Result<(), MetalKernelError> { - let pipeline = kernels.load_pipeline(device, Source::Fill, Self::FILL_KERNEL)?; - let encoder = command_buffer.new_compute_command_encoder(); - encoder.wait_for_fence(&kernels.fence); - encoder.set_compute_pipeline_state(&pipeline); - encoder.set_threadgroup_memory_length(0, elem_count as NSUInteger); - - set_params!(encoder, (buffer, value, elem_count)); - - let (thread_group_count, thread_group_size) = linear_split(&pipeline, elem_count); - encoder.dispatch_thread_groups(thread_group_count, thread_group_size); - encoder.use_resource(buffer, metal::MTLResourceUsage::Write); - encoder.update_fence(&kernels.fence); - encoder.end_encoding(); - - Ok(()) - } } )* }; } impl_call_fill!(u32, i64, f16, bf16, f32); -impl FillOp for Unary { - const FILL_KERNEL: &'static str = ""; - - #[inline(always)] - fn fill( - _: &Device, - command_buffer: &CommandBufferRef, - kernels: &Kernels, - elem_count: usize, - buffer: &Buffer, - value: u8, - ) -> Result<(), MetalKernelError> { - let blit = command_buffer.new_blit_command_encoder(); - blit.wait_for_fence(&kernels.fence); - blit.fill_buffer( - &buffer, - metal::NSRange { - location: 0, - length: elem_count as NSUInteger, - }, - value, - ); - blit.update_fence(&kernels.fence); - blit.end_encoding(); - - Ok(()) - } -} - #[cfg(test)] mod tests; diff --git a/candle-metal-kernels/src/tests.rs b/candle-metal-kernels/src/tests.rs index b7bff740f2..4b27d16302 100644 --- a/candle-metal-kernels/src/tests.rs +++ b/candle-metal-kernels/src/tests.rs @@ -847,4 +847,4 @@ fn fill() { assert_fill(f16::from_f32(1.23)); assert_fill(bf16::from_f32(4.56)); assert_fill(7.89f32); -} \ No newline at end of file +} From 45936a18f8970fd1aa9288f550f8efd747394240 Mon Sep 17 00:00:00 2001 From: Ivar Flakstad <69173633+ivarflakstad@users.noreply.github.com> Date: Tue, 9 Jan 2024 18:54:48 +0100 Subject: [PATCH 7/9] Update with feature separated benchmarks --- candle-core/Cargo.toml | 5 -- candle-core/benches/bench_main.rs | 2 +- candle-core/benches/benchmarks/fill.rs | 43 +++++++++++++++++ candle-core/benches/benchmarks/mod.rs | 1 + candle-core/benches/fill.rs | 67 -------------------------- 5 files changed, 45 insertions(+), 73 deletions(-) create mode 100644 candle-core/benches/benchmarks/fill.rs delete mode 100644 candle-core/benches/fill.rs diff --git a/candle-core/Cargo.toml b/candle-core/Cargo.toml index 1b279999e9..afdb67cd81 100644 --- a/candle-core/Cargo.toml +++ b/candle-core/Cargo.toml @@ -48,8 +48,3 @@ metal = ["dep:metal", "dep:candle-metal-kernels"] [[bench]] name = "bench_main" harness = false - -[[bench]] -name = "fill" -harness = false - diff --git a/candle-core/benches/bench_main.rs b/candle-core/benches/bench_main.rs index 4425f2fb32..6362d8034c 100644 --- a/candle-core/benches/bench_main.rs +++ b/candle-core/benches/bench_main.rs @@ -1,4 +1,4 @@ mod benchmarks; use criterion::criterion_main; -criterion_main!(benchmarks::matmul::benches); +criterion_main!(benchmarks::matmul::benches, benchmarks::fill::benches); diff --git a/candle-core/benches/benchmarks/fill.rs b/candle-core/benches/benchmarks/fill.rs new file mode 100644 index 0000000000..94268aa803 --- /dev/null +++ b/candle-core/benches/benchmarks/fill.rs @@ -0,0 +1,43 @@ 
+use crate::benchmarks::{bench_name, device, BenchDevice}; +use candle_core::{DType, Device, Tensor}; +use criterion::{black_box, criterion_group, Criterion, Throughput}; +use std::time::Instant; + +fn run(shape: (usize, usize, usize), dtype: DType, device: &Device) { + Tensor::ones(shape, dtype, device).unwrap(); +} + +fn run_fill_benchmark(c: &mut Criterion, name: &str, dtype: DType) { + let b = 1; + let rows = 4096; + let columns = 4096; + + let flops = b * rows * columns * dtype.size_in_bytes(); + + let device = device().unwrap(); + + let mut group = c.benchmark_group(bench_name(name)); + group.throughput(Throughput::Bytes(flops as u64)); + group.bench_function("iter", move |bencher| { + bencher.iter_custom(|iters| { + let start = Instant::now(); + for _i in 0..iters { + run( + black_box((b, rows, columns)), + black_box(dtype), + black_box(&device), + ); + } + device.sync().unwrap(); + start.elapsed() + }) + }); + group.finish(); +} + +fn criterion_benchmark(c: &mut Criterion) { + run_fill_benchmark(c, "fill_u8", DType::U8); + run_fill_benchmark(c, "fill_f32", DType::F32); +} + +criterion_group!(benches, criterion_benchmark); diff --git a/candle-core/benches/benchmarks/mod.rs b/candle-core/benches/benchmarks/mod.rs index 1344770dce..9bfbf83aca 100644 --- a/candle-core/benches/benchmarks/mod.rs +++ b/candle-core/benches/benchmarks/mod.rs @@ -1,3 +1,4 @@ +pub(crate) mod fill; pub(crate) mod matmul; use candle_core::{Device, Result}; diff --git a/candle-core/benches/fill.rs b/candle-core/benches/fill.rs deleted file mode 100644 index 9bd0aa728d..0000000000 --- a/candle-core/benches/fill.rs +++ /dev/null @@ -1,67 +0,0 @@ -use candle_core::{DType, Device, Tensor}; -use criterion::{black_box, criterion_group, criterion_main, Criterion, Throughput}; -use std::time::Instant; - -fn run(shape: (usize, usize, usize), dtype: DType, device: &Device) { - Tensor::ones(shape, dtype, device).unwrap(); -} - -fn criterion_benchmark(c: &mut Criterion) { - let b = 1; - let rows = 4096; - let columns = 4096; - - let flops = b * rows * columns; - - let device1 = Device::new_metal(0).unwrap(); - let device2 = device1.clone(); - - let mut group = c.benchmark_group("fill_metal_u8"); - group.throughput(Throughput::Bytes(flops as u64)); - group.bench_function("iter", move |bencher| { - bencher.iter_custom(|iters| { - let start = Instant::now(); - for _i in 0..iters { - run( - black_box((b, rows, columns)), - black_box(DType::U8), - black_box(&device1), - ); - } - if let Device::Metal(device) = &device1 { - device.wait_until_completed().unwrap(); - } else { - panic!("Expected metal device"); - } - start.elapsed() - }) - }); - group.finish(); - - let mut group = c.benchmark_group("fill_metal_f32"); - group.throughput(Throughput::Bytes( - (flops * DType::F32.size_in_bytes()) as u64, - )); - group.bench_function("iter", move |bencher| { - bencher.iter_custom(|iters| { - let start = Instant::now(); - for _i in 0..iters { - run( - black_box((b, rows, columns)), - black_box(DType::F32), - black_box(&device2), - ); - } - if let Device::Metal(device) = &device2 { - device.wait_until_completed().unwrap(); - } else { - panic!("Expected metal device"); - } - start.elapsed() - }) - }); - group.finish(); -} - -criterion_group!(benches, criterion_benchmark); -criterion_main!(benches); From b9ce263e4deb5c1a9e72996a986bf07016604a07 Mon Sep 17 00:00:00 2001 From: Ivar Flakstad <69173633+ivarflakstad@users.noreply.github.com> Date: Thu, 18 Jan 2024 12:07:49 +0100 Subject: [PATCH 8/9] Metal version check for fill_i64 --- 
candle-metal-kernels/src/fill.metal | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/candle-metal-kernels/src/fill.metal b/candle-metal-kernels/src/fill.metal index 9d09980e79..84209954f5 100644 --- a/candle-metal-kernels/src/fill.metal +++ b/candle-metal-kernels/src/fill.metal @@ -22,10 +22,13 @@ kernel void FN_NAME( \ FILL_OP(uint8_t, fill_u8) FILL_OP(uint32_t, fill_u32) -FILL_OP(int64_t, fill_i64) FILL_OP(half, fill_f16) FILL_OP(float, fill_f32) -#if __METAL_VERSION__ >= 310 +#if __METAL_VERSION__ >= 220 +FILL_OP(int64_t, fill_i64) +#endif + +#if defined(__HAVE_BFLOAT__) FILL_OP(bfloat, fill_bf16) #endif From ceaf7f1e2d331272d9c38c1af0bcaee42245333e Mon Sep 17 00:00:00 2001 From: Ivar Flakstad <69173633+ivarflakstad@users.noreply.github.com> Date: Mon, 22 Jan 2024 21:17:20 +0100 Subject: [PATCH 9/9] More concise macros --- candle-metal-kernels/src/lib.rs | 47 +++++++++------------------------ 1 file changed, 13 insertions(+), 34 deletions(-) diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs index c06f168ff6..dc9f4f4da8 100644 --- a/candle-metal-kernels/src/lib.rs +++ b/candle-metal-kernels/src/lib.rs @@ -53,23 +53,18 @@ pub trait EncoderParam: private::Sealed { fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self); } -macro_rules! primitive { - ($type:ty) => { - impl EncoderParam for $type { - fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self) { - encoder.set_bytes( - position, - core::mem::size_of::<$type>() as u64, - &data as *const $type as *const c_void, - ); - } - } - }; -} macro_rules! primitives { ($($type:ty),+) => { $( - primitive!($type); + impl EncoderParam for $type { + fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self) { + encoder.set_bytes( + position, + core::mem::size_of::<$type>() as u64, + &data as *const $type as *const c_void, + ); + } + } )+ }; } @@ -116,7 +111,7 @@ macro_rules! set_params { ); } -// Seal the trait so that only the types we want can implement it +// Seal for EncoderParam so that only the types we want can implement it mod private { use super::*; @@ -124,27 +119,11 @@ mod private { macro_rules! sealed { ($($type:ty),+) => { - $( - impl Sealed for $type {} - )+ + $(impl Sealed for $type {})+ }; } - sealed!( - usize, - u8, - u32, - u64, - i32, - i64, - f16, - bf16, - f32, - bool, - &Buffer, - (&Buffer, usize), - &mut Buffer, - (&mut Buffer, usize) - ); + sealed!(usize, u8, u32, u64, i32, i64, f16, bf16, f32, bool); + sealed!(&Buffer, (&Buffer, usize), &mut Buffer, (&mut Buffer, usize)); impl Sealed for &[T] {} }
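A note on how the final dispatch fits together: call_fill loads the compute pipeline named by T::FILL_KERNEL from Source::Fill, and impl_call_fill! derives that name from the Rust type with concat!/stringify!, so f32 resolves to the fill_f32 kernel defined in fill.metal, while u8 skips the compute path entirely via call_fill_u8 and the blit encoder's fill_buffer. A stripped-down model of the naming scheme (standalone sketch; it omits the EncoderParam supertrait and the encoder plumbing of the real crate, and leaves out f16/bf16 so it builds without the half dependency):

    // Minimal model of the FillOp naming scheme used by call_fill.
    trait FillOp {
        const FILL_KERNEL: &'static str;
    }

    macro_rules! impl_call_fill {
        ($($t:ty),*) => {
            $(
                impl FillOp for $t {
                    // concat!("fill_", stringify!(f32)) == "fill_f32", etc.
                    const FILL_KERNEL: &'static str = concat!("fill_", stringify!($t));
                }
            )*
        };
    }
    impl_call_fill!(u32, i64, f32);

    fn main() {
        // Prints "fill_f32": the compute function the real call_fill would
        // load from the fill.metal source when filling an f32 buffer.
        println!("{}", <f32 as FillOp>::FILL_KERNEL);
    }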
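End to end, calling the kernel from outside the crate mirrors the run_fill test helper. The sketch below is an assumption-laden usage example, not part of the patches: it assumes the post-patch-6 exports (call_fill, Kernels) and the metal-rs device/fence/command-buffer APIs already used elsewhere in the series; buffer sizing and synchronization follow the test code.

    use candle_metal_kernels::{call_fill, Kernels};
    use metal::{Device, MTLResourceOptions};

    fn main() {
        let device = Device::system_default().expect("no Metal device found");
        let fence = device.new_fence();
        let kernels = Kernels::new(fence);
        let command_queue = device.new_command_queue();
        let command_buffer = command_queue.new_command_buffer();

        // 1024 f32 elements, filled with 3.5 by the fill_f32 compute kernel.
        let elem_count = 1024usize;
        let size = (elem_count * std::mem::size_of::<f32>()) as u64;
        let buffer = device.new_buffer(size, MTLResourceOptions::StorageModeManaged);

        call_fill(&device, command_buffer, &kernels, elem_count, &buffer, 3.5f32).unwrap();
        command_buffer.commit();
        command_buffer.wait_until_completed();
    }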
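One detail in the fill test is worth spelling out: in Rust, ^ is bitwise XOR rather than exponentiation, so run_fill(8 ^ i, value) over i in 0..4 exercises element counts 8, 9, 10, and 11 rather than powers of eight. A hypothetical adjustment to the loop body that actually sweeps 1, 8, 64, and 512 would read:

    // Inside assert_fill: sweep element counts 1, 8, 64, 512.
    for i in 0..4u32 {
        let n = 8usize.pow(i);
        assert_eq!(run_fill(n, value), vec![value; n]);
    }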