diff --git a/candle-core/benches/bench_main.rs b/candle-core/benches/bench_main.rs
index 92c33a8675..4e508a3941 100644
--- a/candle-core/benches/bench_main.rs
+++ b/candle-core/benches/bench_main.rs
@@ -1,4 +1,4 @@
 mod benchmarks;
 
 use criterion::criterion_main;
-criterion_main!(benchmarks::matmul::benches, benchmarks::where_cond::benches);
+criterion_main!(benchmarks::matmul::benches, benchmarks::affine::benches, benchmarks::where_cond::benches);
\ No newline at end of file
diff --git a/candle-core/benches/benchmarks/affine.rs b/candle-core/benches/benchmarks/affine.rs
new file mode 100644
index 0000000000..eded9f5787
--- /dev/null
+++ b/candle-core/benches/benchmarks/affine.rs
@@ -0,0 +1,43 @@
+use crate::benchmarks::{BenchDevice, BenchDeviceHandler};
+use candle_core::{DType, Device, Tensor};
+use criterion::{black_box, criterion_group, Criterion, Throughput};
+use std::time::Instant;
+
+fn run(a: &Tensor) {
+    a.affine(12.34, 56.78).unwrap();
+}
+
+fn run_affine_benchmark(c: &mut Criterion, device: &Device, dtype: DType, name: &str) {
+    let b = 1;
+    let m = 1024;
+    let k = 1024;
+
+    let tensor = Tensor::zeros((b, m, k), dtype, &device).unwrap();
+
+    let flops = b * m * k * dtype.size_in_bytes();
+
+    let mut group = c.benchmark_group(device.bench_name(name));
+    group.throughput(Throughput::Bytes(flops as u64));
+    group.bench_function("iter", move |b| {
+        b.iter_custom(|iters| {
+            let start = Instant::now();
+            for _i in 0..iters {
+                run(black_box(&tensor));
+            }
+            device.sync().unwrap();
+            start.elapsed()
+        })
+    });
+    group.finish();
+}
+
+fn criterion_benchmark(c: &mut Criterion) {
+    let handler = BenchDeviceHandler::new().unwrap();
+    for device in handler.devices {
+        run_affine_benchmark(c, &device, DType::F32, "affine_f32");
+        run_affine_benchmark(c, &device, DType::F16, "affine_f16");
+        run_affine_benchmark(c, &device, DType::BF16, "affine_bf16");
+    }
+}
+
+criterion_group!(benches, criterion_benchmark);
diff --git a/candle-core/benches/benchmarks/mod.rs b/candle-core/benches/benchmarks/mod.rs
index 4e73ebb67c..7dacff5e36 100644
--- a/candle-core/benches/benchmarks/mod.rs
+++ b/candle-core/benches/benchmarks/mod.rs
@@ -1,3 +1,4 @@
+pub(crate) mod affine;
 pub(crate) mod matmul;
 pub(crate) mod where_cond;
 
diff --git a/candle-core/src/metal_backend.rs b/candle-core/src/metal_backend.rs
index 38f909c859..5269a89960 100644
--- a/candle-core/src/metal_backend.rs
+++ b/candle-core/src/metal_backend.rs
@@ -353,6 +353,7 @@ impl BackendStorage for MetalStorage {
             let name = match self.dtype {
                 DType::F32 => "affine_f32",
                 DType::F16 => "affine_f16",
+                DType::BF16 => "affine_bf16",
                 dtype => crate::bail!("Metal contiguous affine {dtype:?} not implemented"),
             };
             candle_metal_kernels::call_affine(
@@ -371,6 +372,7 @@ impl BackendStorage for MetalStorage {
             let name = match self.dtype {
                 DType::F32 => "affine_f32_strided",
                 DType::F16 => "affine_f16_strided",
+                DType::BF16 => "affine_bf16_strided",
                 dtype => crate::bail!("Metal strided affine {dtype:?} not implemented"),
             };
             candle_metal_kernels::call_affine_strided(
diff --git a/candle-metal-kernels/src/affine.metal b/candle-metal-kernels/src/affine.metal
index 3d8e7f0da3..a4484998c5 100644
--- a/candle-metal-kernels/src/affine.metal
+++ b/candle-metal-kernels/src/affine.metal
@@ -17,19 +17,19 @@ METAL_FUNC uint get_strided_index(
 
 using namespace metal;
 
-#define AFFINE(FN_NAME, TYPENAME) \
+#define AFFINE(FN_NAME, T) \
 kernel void FN_NAME( \
     constant size_t &dim, \
     constant float &mul, \
     constant float &add, \
-    device const TYPENAME *input, \
-    device TYPENAME *output, \
+    device const T *input, \
+    device T *output, \
     uint id [[ thread_position_in_grid ]] \
 ) { \
     if (id >= dim) { \
         return; \
     } \
-    output[id] = TYPENAME(float(input[id]) * mul + add); \
+    output[id] = T(fma(float(input[id]), mul, add)); \
 } \
 kernel void FN_NAME##_strided( \
     constant size_t &dim, \
@@ -38,14 +38,14 @@ kernel void FN_NAME##_strided( \
     constant size_t *strides, \
     constant float &mul, \
     constant float &add, \
-    device const TYPENAME *input, \
-    device TYPENAME *output, \
+    device const T *input, \
+    device T *output, \
     uint id [[ thread_position_in_grid ]] \
 ) { \
     if (id >= dim) { \
         return; \
     } \
-    output[id] = TYPENAME(float(input[get_strided_index(id, num_dims, dims, strides)]) * mul + add); \
+    output[id] = T(fma(float(input[get_strided_index(id, num_dims, dims, strides)]), mul, add)); \
 }
 
 #define POWF(FN_NAME, TYPENAME) \
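
Usage sketch (illustrative only, not part of the patch): the benchmark above exercises the new BF16 Metal path through Tensor::affine. A minimal standalone check of the same call, assuming a Metal-capable build of candle-core and falling back to CPU otherwise:

use candle_core::{DType, Device, Tensor};

fn main() -> candle_core::Result<()> {
    // Prefer the Metal device; fall back to CPU if Metal is unavailable in this build.
    let device = Device::new_metal(0).unwrap_or(Device::Cpu);
    // A contiguous BF16 tensor routes through the new affine_bf16 kernel on Metal
    // (non-contiguous layouts would use affine_bf16_strided instead).
    let t = Tensor::zeros((1, 1024, 1024), DType::BF16, &device)?;
    // Same affine call as the benchmark: out = t * 12.34 + 56.78.
    let out = t.affine(12.34, 56.78)?;
    println!("{:?} {:?}", out.shape(), out.dtype());
    Ok(())
}

The benchmark itself runs through the usual Criterion flow (cargo bench) and reports throughput via Throughput::Bytes, so the F32, F16, and BF16 results read as memory-bandwidth figures that are directly comparable.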