Skip to content

Commit

Permalink
Metal: Activate bfloat affine and add benchmark (#1543)
Browse files Browse the repository at this point in the history
* Use cfg to separate benchmark results based on features

* Add bfloat affine and benchmarks

* Fix flops calculation

* Remove allow pragma

* Avoid some unnecessary returns.

* Improve benchmarks layout

---------

Co-authored-by: Laurent <laurent.mazare@gmail.com>
Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com>
  • Loading branch information
3 people authored Jan 12, 2024
1 parent e90bcdc commit a3d92ab
Show file tree
Hide file tree
Showing 5 changed files with 54 additions and 8 deletions.
2 changes: 1 addition & 1 deletion candle-core/benches/bench_main.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
mod benchmarks;

use criterion::criterion_main;
criterion_main!(benchmarks::matmul::benches, benchmarks::where_cond::benches);
criterion_main!(benchmarks::matmul::benches, benchmarks::affine::benches, benchmarks::where_cond::benches);
43 changes: 43 additions & 0 deletions candle-core/benches/benchmarks/affine.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
use crate::benchmarks::{BenchDevice, BenchDeviceHandler};
use candle_core::{DType, Device, Tensor};
use criterion::{black_box, criterion_group, Criterion, Throughput};
use std::time::Instant;

/// Apply a fixed affine transform (`x * 12.34 + 56.78`) to the tensor,
/// discarding the result. Panics if the operation fails.
fn run(a: &Tensor) {
    let transformed = a.affine(12.34, 56.78);
    transformed.unwrap();
}

/// Benchmark the `affine` op on `device` for the given `dtype`, reporting
/// throughput in bytes/second under the benchmark group `name`.
fn run_affine_benchmark(c: &mut Criterion, device: &Device, dtype: DType, name: &str) {
    // Fixed problem size: a (1, 1024, 1024) tensor of zeros.
    let b = 1;
    let m = 1024;
    let k = 1024;

    // `device` is already a `&Device`; pass it directly. The original `&device`
    // was a needless double borrow (clippy::needless_borrow) that only
    // compiled via deref coercion.
    let tensor = Tensor::zeros((b, m, k), dtype, device).unwrap();

    // This is a byte count (element count * element size), matching the
    // `Throughput::Bytes` unit below — so name it `bytes`, not `flops`.
    let bytes = b * m * k * dtype.size_in_bytes();

    let mut group = c.benchmark_group(device.bench_name(name));
    group.throughput(Throughput::Bytes(bytes as u64));
    // Renamed the closure parameter from `b` to `bencher` so it no longer
    // shadows the batch-size local above.
    group.bench_function("iter", move |bencher| {
        bencher.iter_custom(|iters| {
            let start = Instant::now();
            for _i in 0..iters {
                run(black_box(&tensor));
            }
            // Device work may be asynchronous; synchronize before reading the
            // clock so the elapsed time includes the kernels themselves.
            device.sync().unwrap();
            start.elapsed()
        })
    });
    group.finish();
}

/// Register the affine benchmark on every available device, covering the
/// f32, f16, and bf16 element types.
fn criterion_benchmark(c: &mut Criterion) {
    let handler = BenchDeviceHandler::new().unwrap();
    // Table of (dtype, benchmark-group name) cases, run in the same order
    // as before: f32, then f16, then bf16, per device.
    let cases = [
        (DType::F32, "affine_f32"),
        (DType::F16, "affine_f16"),
        (DType::BF16, "affine_bf16"),
    ];
    for device in handler.devices {
        for &(dtype, name) in &cases {
            run_affine_benchmark(c, &device, dtype, name);
        }
    }
}

criterion_group!(benches, criterion_benchmark);
1 change: 1 addition & 0 deletions candle-core/benches/benchmarks/mod.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
pub(crate) mod affine;
pub(crate) mod matmul;
pub(crate) mod where_cond;

Expand Down
2 changes: 2 additions & 0 deletions candle-core/src/metal_backend.rs
Original file line number Diff line number Diff line change
Expand Up @@ -353,6 +353,7 @@ impl BackendStorage for MetalStorage {
let name = match self.dtype {
DType::F32 => "affine_f32",
DType::F16 => "affine_f16",
DType::BF16 => "affine_bf16",
dtype => crate::bail!("Metal contiguous affine {dtype:?} not implemented"),
};
candle_metal_kernels::call_affine(
Expand All @@ -371,6 +372,7 @@ impl BackendStorage for MetalStorage {
let name = match self.dtype {
DType::F32 => "affine_f32_strided",
DType::F16 => "affine_f16_strided",
DType::BF16 => "affine_bf16_strided",
dtype => crate::bail!("Metal strided affine {dtype:?} not implemented"),
};
candle_metal_kernels::call_affine_strided(
Expand Down
14 changes: 7 additions & 7 deletions candle-metal-kernels/src/affine.metal
Original file line number Diff line number Diff line change
Expand Up @@ -17,19 +17,19 @@ METAL_FUNC uint get_strided_index(

using namespace metal;

#define AFFINE(FN_NAME, TYPENAME) \
#define AFFINE(FN_NAME, T) \
kernel void FN_NAME( \
constant size_t &dim, \
constant float &mul, \
constant float &add, \
device const TYPENAME *input, \
device TYPENAME *output, \
device const T *input, \
device T *output, \
uint id [[ thread_position_in_grid ]] \
) { \
if (id >= dim) { \
return; \
} \
output[id] = TYPENAME(float(input[id]) * mul + add); \
output[id] = T(fma(float(input[id]), mul, add)); \
} \
kernel void FN_NAME##_strided( \
constant size_t &dim, \
Expand All @@ -38,14 +38,14 @@ kernel void FN_NAME##_strided( \
constant size_t *strides, \
constant float &mul, \
constant float &add, \
device const TYPENAME *input, \
device TYPENAME *output, \
device const T *input, \
device T *output, \
uint id [[ thread_position_in_grid ]] \
) { \
if (id >= dim) { \
return; \
} \
output[id] = TYPENAME(float(input[get_strided_index(id, num_dims, dims, strides)]) * mul + add); \
output[id] = T(fma(float(input[get_strided_index(id, num_dims, dims, strides)]), mul, add)); \
}

#define POWF(FN_NAME, TYPENAME) \
Expand Down

0 comments on commit a3d92ab

Please sign in to comment.