From 3f04a79ada7ca974176a0c7c3c3306f394eae9a9 Mon Sep 17 00:00:00 2001 From: Ivar Flakstad <69173633+ivarflakstad@users.noreply.github.com> Date: Sun, 7 Jan 2024 14:40:15 +0100 Subject: [PATCH 1/6] Use cfg to seperate benchmark results based on features --- candle-core/benches/bench_utils.rs | 56 ++++++++++++++++++++++++++++++ candle-core/benches/matmul.rs | 16 ++++----- 2 files changed, 64 insertions(+), 8 deletions(-) create mode 100644 candle-core/benches/bench_utils.rs diff --git a/candle-core/benches/bench_utils.rs b/candle-core/benches/bench_utils.rs new file mode 100644 index 0000000000..75800761f4 --- /dev/null +++ b/candle-core/benches/bench_utils.rs @@ -0,0 +1,56 @@ +use candle_core::{Device, Result}; + +pub(crate) trait BenchDevice { + fn sync(&self) -> Result<()>; +} + +impl BenchDevice for Device { + fn sync(&self) -> Result<()> { + match self { + Device::Cpu => Ok(()), + Device::Cuda(device) => { + #[cfg(feature = "cuda")] + return Ok(device.synchronize()?); + #[cfg(not(feature = "cuda"))] + panic!("Cuda device without cuda feature enabled: {:?}", device) + } + Device::Metal(device) => { + #[cfg(feature = "metal")] + return Ok(device.wait_until_completed()?); + #[cfg(not(feature = "metal"))] + panic!("Metal device without metal feature enabled: {:?}", device) + } + } + } +} + +#[allow(dead_code)] +pub(crate) fn device() -> Result { + return if cfg!(feature = "metal") { + Device::new_metal(0) + } else if cfg!(feature = "cuda") { + Device::new_cuda(0) + } else { + Ok(Device::Cpu) + }; +} + +#[allow(dead_code)] +pub(crate) fn bench_name>(name: S) -> String { + format!("{}_{}", device_variant(), name.into()) +} + +#[allow(dead_code)] +const fn device_variant() -> &'static str { + return if cfg!(feature = "metal") { + "metal" + } else if cfg!(feature = "cuda") { + "cuda" + } else if cfg!(feature = "accelerate") { + "accelerate" + } else if cfg!(feature = "mkl") { + "mkl" + } else { + "cpu" + }; +} diff --git a/candle-core/benches/matmul.rs b/candle-core/benches/matmul.rs index 83679771b2..4f7dfa6c9f 100644 --- a/candle-core/benches/matmul.rs +++ b/candle-core/benches/matmul.rs @@ -1,4 +1,8 @@ -use candle_core::{DType, Device, Tensor}; +mod bench_utils; + +use crate::bench_utils::bench_name; +use bench_utils::{device, BenchDevice}; +use candle_core::{DType, Tensor}; use criterion::{black_box, criterion_group, criterion_main, Criterion, Throughput}; use std::time::Instant; @@ -12,14 +16,14 @@ fn criterion_benchmark(c: &mut Criterion) { let n = 2048; let k = 2048; - let device = Device::new_metal(0).unwrap(); + let device = device().unwrap(); let dtype = DType::F32; let lhs = Tensor::zeros((b, m, k), dtype, &device).unwrap(); let rhs = Tensor::zeros((b, n, k), dtype, &device).unwrap(); let flops = b * m * n * k; - let mut group = c.benchmark_group("matmul_metal"); + let mut group = c.benchmark_group(bench_name("matmul")); group.throughput(Throughput::Bytes(flops as u64)); group.bench_function("iter", move |b| { b.iter_custom(|iters| { @@ -27,11 +31,7 @@ fn criterion_benchmark(c: &mut Criterion) { for _i in 0..iters { run(black_box(&lhs), black_box(&rhs)); } - if let Device::Metal(device) = &device { - device.wait_until_completed().unwrap(); - } else { - panic!("Expected metal device"); - } + device.sync().unwrap(); start.elapsed() }) }); From febb4a57fdee8acd1f94fbd9e14c37ce93dcf10b Mon Sep 17 00:00:00 2001 From: Ivar Flakstad <69173633+ivarflakstad@users.noreply.github.com> Date: Sun, 7 Jan 2024 19:40:30 +0100 Subject: [PATCH 2/6] Add metal where_cond for f16 and bf16. Add benchmark --- candle-core/Cargo.toml | 4 ++ candle-core/benches/where_cond.rs | 66 ++++++++++++++++++++++++++ candle-core/src/metal_backend.rs | 1 + candle-metal-kernels/src/ternary.metal | 66 +++++++++++++++++--------- 4 files changed, 114 insertions(+), 23 deletions(-) create mode 100644 candle-core/benches/where_cond.rs diff --git a/candle-core/Cargo.toml b/candle-core/Cargo.toml index 91655f5782..2385a76b15 100644 --- a/candle-core/Cargo.toml +++ b/candle-core/Cargo.toml @@ -49,3 +49,7 @@ metal = ["dep:metal", "dep:candle-metal-kernels"] name = "matmul" harness = false +[[bench]] +name = "where_cond" +harness = false + diff --git a/candle-core/benches/where_cond.rs b/candle-core/benches/where_cond.rs new file mode 100644 index 0000000000..9dd943e644 --- /dev/null +++ b/candle-core/benches/where_cond.rs @@ -0,0 +1,66 @@ +mod bench_utils; + +use crate::bench_utils::bench_name; +use bench_utils::{device, BenchDevice}; +use candle_core::{DType, Tensor}; +use criterion::{black_box, criterion_group, criterion_main, Criterion, Throughput}; +use std::time::Instant; + +fn run(a: &Tensor, b: &Tensor, c: &Tensor) { + a.where_cond(b, c).unwrap(); +} + +const fn create_cond_arr() -> [u8; N] { + let mut arr = [0u8; N]; + let mut i = 0; + while i < N { + arr[i] = (i % 2) as u8; + i += 1; + } + arr +} + +const B: usize = 1; +const M: usize = 1024; +const K: usize = 1024; +const SIZE: usize = B * M * K; + +const DATA: [u8; SIZE] = create_cond_arr::(); + +fn run_where_cond_benchmark(c: &mut Criterion, dtype: DType, name: &str) { + let device = device().unwrap(); + let tensor = Tensor::from_slice(DATA.as_slice(), (B, M, K), &device).unwrap(); + let on_true = Tensor::ones((B, M, K), dtype, &device).unwrap(); + let on_false = Tensor::zeros((B, M, K), dtype, &device).unwrap(); + + let elements = B * M * K; + // E.g. 2 f32 tensors + 1 u8 tensor + let flops = (2 * elements * dtype.size_in_bytes()) + elements; + + let mut group = c.benchmark_group(bench_name(name)); + group.throughput(Throughput::Bytes(flops as u64)); + group.bench_function("iter", move |b| { + b.iter_custom(|iters| { + let start = Instant::now(); + for _i in 0..iters { + run( + black_box(&tensor), + black_box(&on_true), + black_box(&on_false), + ); + } + device.sync().unwrap(); + start.elapsed() + }) + }); + group.finish(); +} + +fn criterion_benchmark(c: &mut Criterion) { + run_where_cond_benchmark(c, DType::F32, "where_cond_f32"); + run_where_cond_benchmark(c, DType::BF16, "where_cond_bf16"); + run_where_cond_benchmark(c, DType::F16, "where_cond_f16"); +} + +criterion_group!(benches, criterion_benchmark); +criterion_main!(benches); diff --git a/candle-core/src/metal_backend.rs b/candle-core/src/metal_backend.rs index c1c4aa4bc0..cc06a1b6cd 100644 --- a/candle-core/src/metal_backend.rs +++ b/candle-core/src/metal_backend.rs @@ -790,6 +790,7 @@ impl BackendStorage for MetalStorage { } let name = match (self.dtype, t.dtype()) { (DType::U8, DType::F32) => "where_u8_f32", + (DType::U8, DType::BF16) => "where_u8_bf16", (DType::U8, DType::F16) => "where_u8_f16", (DType::U8, DType::I64) => "where_u8_i64", (DType::U8, DType::U32) => "where_u8_u32", diff --git a/candle-metal-kernels/src/ternary.metal b/candle-metal-kernels/src/ternary.metal index 40b4bcf42f..7b3b8ca9eb 100644 --- a/candle-metal-kernels/src/ternary.metal +++ b/candle-metal-kernels/src/ternary.metal @@ -17,29 +17,45 @@ METAL_FUNC uint get_strided_index( return strided_i; } +template +METAL_FUNC void where_cond( + constant size_t &numel, + constant size_t &num_dims, + constant size_t *dims, + constant size_t *strides, + constant size_t *strides_t, + constant size_t *strides_f, + device const ID *ids, + device const T *t, + device const T *f, + device T *out, + uint i [[ thread_position_in_grid ]] +) { + if (i >= numel){ + return; + } + uint strided_i = get_strided_index(i, num_dims, dims, strides); + uint strided_i_t = get_strided_index(i, num_dims, dims, strides_t); + uint strided_i_f = get_strided_index(i, num_dims, dims, strides_f); + out[i] = ids[strided_i] ? t[strided_i_t] : f[strided_i_f]; +} -#define WHERE_OP(TYPENAME, ID_TYPENAME, FN_NAME) \ -kernel void FN_NAME( \ - constant size_t &numel, \ - constant size_t &num_dims, \ - constant size_t *dims, \ - constant size_t *strides, \ - constant size_t *strides_t, \ - constant size_t *strides_f, \ - device const ID_TYPENAME *ids, \ - device const TYPENAME *t, \ - device const TYPENAME *f, \ - device TYPENAME *out ,\ - uint i [[ thread_position_in_grid ]] \ -) { \ - if (i >= numel){ \ - return; \ - } \ - uint strided_i = get_strided_index(i, num_dims, dims, strides); \ - uint strided_i_t = get_strided_index(i, num_dims, dims, strides_t); \ - uint strided_i_f = get_strided_index(i, num_dims, dims, strides_f); \ - out[i] = ids[strided_i] ? t[strided_i_t] : f[strided_i_f]; \ -} \ +#define WHERE_OP(T, ID, FN_NAME) \ +kernel void FN_NAME( \ + constant size_t &numel, \ + constant size_t &num_dims, \ + constant size_t *dims, \ + constant size_t *strides, \ + constant size_t *strides_t, \ + constant size_t *strides_f, \ + device const ID *ids, \ + device const T *t, \ + device const T *f, \ + device T *out, \ + uint i [[ thread_position_in_grid ]] \ +) { \ + where_cond(numel, num_dims, dims, strides, strides_t, strides_f, ids, t, f, out, i); \ +} \ // WHERE_OP(float, int64_t, where_i64_f32) // WHERE_OP(double, int64_t, where_i64_f64) @@ -54,10 +70,14 @@ kernel void FN_NAME( \ // WHERE_OP(int64_t, uint32_t, where_u32_i64) WHERE_OP(float, uint8_t, where_u8_f32) -// WHERE_OP(double, uint8_t, where_u8_f64) +WHERE_OP(half, uint8_t, where_u8_f16) WHERE_OP(uint8_t, uint8_t, where_u8_u8) WHERE_OP(uint32_t, uint8_t, where_u8_u32) #if __METAL_VERSION__ >= 220 WHERE_OP(int64_t, uint8_t, where_u8_i64) #endif + +#if defined(__HAVE_BFLOAT__) +WHERE_OP(bfloat, uint8_t, where_u8_bf16) +#endif \ No newline at end of file From ad075a5f7edb0114b820b3e99a19b17d0d25ec3b Mon Sep 17 00:00:00 2001 From: Ivar Flakstad <69173633+ivarflakstad@users.noreply.github.com> Date: Mon, 8 Jan 2024 06:48:33 +0100 Subject: [PATCH 3/6] Remove allow pragma --- candle-core/benches/matmul.rs | 5 ++--- candle-core/benches/{bench_utils.rs => utils.rs} | 3 --- 2 files changed, 2 insertions(+), 6 deletions(-) rename candle-core/benches/{bench_utils.rs => utils.rs} (96%) diff --git a/candle-core/benches/matmul.rs b/candle-core/benches/matmul.rs index 4f7dfa6c9f..a5dba9ccdb 100644 --- a/candle-core/benches/matmul.rs +++ b/candle-core/benches/matmul.rs @@ -1,10 +1,9 @@ -mod bench_utils; +mod utils; -use crate::bench_utils::bench_name; -use bench_utils::{device, BenchDevice}; use candle_core::{DType, Tensor}; use criterion::{black_box, criterion_group, criterion_main, Criterion, Throughput}; use std::time::Instant; +use utils::{bench_name, device, BenchDevice}; fn run(a: &Tensor, b: &Tensor) { a.matmul(&b.t().unwrap()).unwrap(); diff --git a/candle-core/benches/bench_utils.rs b/candle-core/benches/utils.rs similarity index 96% rename from candle-core/benches/bench_utils.rs rename to candle-core/benches/utils.rs index 75800761f4..a93afc6e76 100644 --- a/candle-core/benches/bench_utils.rs +++ b/candle-core/benches/utils.rs @@ -24,7 +24,6 @@ impl BenchDevice for Device { } } -#[allow(dead_code)] pub(crate) fn device() -> Result { return if cfg!(feature = "metal") { Device::new_metal(0) @@ -35,12 +34,10 @@ pub(crate) fn device() -> Result { }; } -#[allow(dead_code)] pub(crate) fn bench_name>(name: S) -> String { format!("{}_{}", device_variant(), name.into()) } -#[allow(dead_code)] const fn device_variant() -> &'static str { return if cfg!(feature = "metal") { "metal" From fb05af4c42a6343856f167893b19974ed6b50276 Mon Sep 17 00:00:00 2001 From: Laurent Date: Mon, 8 Jan 2024 07:19:59 +0100 Subject: [PATCH 4/6] Avoid some unnecessary returns. --- candle-core/benches/utils.rs | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/candle-core/benches/utils.rs b/candle-core/benches/utils.rs index a93afc6e76..3e8b3c5776 100644 --- a/candle-core/benches/utils.rs +++ b/candle-core/benches/utils.rs @@ -25,13 +25,13 @@ impl BenchDevice for Device { } pub(crate) fn device() -> Result { - return if cfg!(feature = "metal") { + if cfg!(feature = "metal") { Device::new_metal(0) } else if cfg!(feature = "cuda") { Device::new_cuda(0) } else { Ok(Device::Cpu) - }; + } } pub(crate) fn bench_name>(name: S) -> String { @@ -39,7 +39,7 @@ pub(crate) fn bench_name>(name: S) -> String { } const fn device_variant() -> &'static str { - return if cfg!(feature = "metal") { + if cfg!(feature = "metal") { "metal" } else if cfg!(feature = "cuda") { "cuda" @@ -49,5 +49,5 @@ const fn device_variant() -> &'static str { "mkl" } else { "cpu" - }; + } } From 88945f2c227d6a76e8f133d8b72c701b180c4f4c Mon Sep 17 00:00:00 2001 From: Ivar Flakstad <69173633+ivarflakstad@users.noreply.github.com> Date: Tue, 9 Jan 2024 18:31:28 +0100 Subject: [PATCH 5/6] Improve benchmarks layout --- candle-core/Cargo.toml | 2 +- candle-core/benches/bench_main.rs | 4 ++++ candle-core/benches/{ => benchmarks}/matmul.rs | 7 ++----- candle-core/benches/{utils.rs => benchmarks/mod.rs} | 2 ++ 4 files changed, 9 insertions(+), 6 deletions(-) create mode 100644 candle-core/benches/bench_main.rs rename candle-core/benches/{ => benchmarks}/matmul.rs (85%) rename candle-core/benches/{utils.rs => benchmarks/mod.rs} (98%) diff --git a/candle-core/Cargo.toml b/candle-core/Cargo.toml index 91655f5782..93b718a3e3 100644 --- a/candle-core/Cargo.toml +++ b/candle-core/Cargo.toml @@ -46,6 +46,6 @@ accelerate = ["dep:libc", "dep:accelerate-src"] metal = ["dep:metal", "dep:candle-metal-kernels"] [[bench]] -name = "matmul" +name = "bench_main" harness = false diff --git a/candle-core/benches/bench_main.rs b/candle-core/benches/bench_main.rs new file mode 100644 index 0000000000..4425f2fb32 --- /dev/null +++ b/candle-core/benches/bench_main.rs @@ -0,0 +1,4 @@ +mod benchmarks; + +use criterion::criterion_main; +criterion_main!(benchmarks::matmul::benches); diff --git a/candle-core/benches/matmul.rs b/candle-core/benches/benchmarks/matmul.rs similarity index 85% rename from candle-core/benches/matmul.rs rename to candle-core/benches/benchmarks/matmul.rs index a5dba9ccdb..fb173f0475 100644 --- a/candle-core/benches/matmul.rs +++ b/candle-core/benches/benchmarks/matmul.rs @@ -1,9 +1,7 @@ -mod utils; - +use crate::benchmarks::{bench_name, device, BenchDevice}; use candle_core::{DType, Tensor}; -use criterion::{black_box, criterion_group, criterion_main, Criterion, Throughput}; +use criterion::{black_box, criterion_group, Criterion, Throughput}; use std::time::Instant; -use utils::{bench_name, device, BenchDevice}; fn run(a: &Tensor, b: &Tensor) { a.matmul(&b.t().unwrap()).unwrap(); @@ -38,4 +36,3 @@ fn criterion_benchmark(c: &mut Criterion) { } criterion_group!(benches, criterion_benchmark); -criterion_main!(benches); diff --git a/candle-core/benches/utils.rs b/candle-core/benches/benchmarks/mod.rs similarity index 98% rename from candle-core/benches/utils.rs rename to candle-core/benches/benchmarks/mod.rs index 3e8b3c5776..1344770dce 100644 --- a/candle-core/benches/utils.rs +++ b/candle-core/benches/benchmarks/mod.rs @@ -1,3 +1,5 @@ +pub(crate) mod matmul; + use candle_core::{Device, Result}; pub(crate) trait BenchDevice { From 349910d5decedced3955eb5da7da90ea01f8066e Mon Sep 17 00:00:00 2001 From: Ivar Flakstad <69173633+ivarflakstad@users.noreply.github.com> Date: Tue, 9 Jan 2024 18:41:31 +0100 Subject: [PATCH 6/6] Updated feature separated benchmarks --- candle-core/Cargo.toml | 5 ----- candle-core/benches/bench_main.rs | 2 +- candle-core/benches/benchmarks/mod.rs | 1 + candle-core/benches/{ => benchmarks}/where_cond.rs | 8 ++------ 4 files changed, 4 insertions(+), 12 deletions(-) rename candle-core/benches/{ => benchmarks}/where_cond.rs (89%) diff --git a/candle-core/Cargo.toml b/candle-core/Cargo.toml index a85ebfb34e..afdb67cd81 100644 --- a/candle-core/Cargo.toml +++ b/candle-core/Cargo.toml @@ -48,8 +48,3 @@ metal = ["dep:metal", "dep:candle-metal-kernels"] [[bench]] name = "bench_main" harness = false - -[[bench]] -name = "where_cond" -harness = false - diff --git a/candle-core/benches/bench_main.rs b/candle-core/benches/bench_main.rs index 4425f2fb32..92c33a8675 100644 --- a/candle-core/benches/bench_main.rs +++ b/candle-core/benches/bench_main.rs @@ -1,4 +1,4 @@ mod benchmarks; use criterion::criterion_main; -criterion_main!(benchmarks::matmul::benches); +criterion_main!(benchmarks::matmul::benches, benchmarks::where_cond::benches); diff --git a/candle-core/benches/benchmarks/mod.rs b/candle-core/benches/benchmarks/mod.rs index 1344770dce..ee80adce85 100644 --- a/candle-core/benches/benchmarks/mod.rs +++ b/candle-core/benches/benchmarks/mod.rs @@ -1,4 +1,5 @@ pub(crate) mod matmul; +pub(crate) mod where_cond; use candle_core::{Device, Result}; diff --git a/candle-core/benches/where_cond.rs b/candle-core/benches/benchmarks/where_cond.rs similarity index 89% rename from candle-core/benches/where_cond.rs rename to candle-core/benches/benchmarks/where_cond.rs index 9dd943e644..d76fb75c00 100644 --- a/candle-core/benches/where_cond.rs +++ b/candle-core/benches/benchmarks/where_cond.rs @@ -1,9 +1,6 @@ -mod bench_utils; - -use crate::bench_utils::bench_name; -use bench_utils::{device, BenchDevice}; +use crate::benchmarks::{bench_name, device, BenchDevice}; use candle_core::{DType, Tensor}; -use criterion::{black_box, criterion_group, criterion_main, Criterion, Throughput}; +use criterion::{black_box, criterion_group, Criterion, Throughput}; use std::time::Instant; fn run(a: &Tensor, b: &Tensor, c: &Tensor) { @@ -63,4 +60,3 @@ fn criterion_benchmark(c: &mut Criterion) { } criterion_group!(benches, criterion_benchmark); -criterion_main!(benches);