Metal: Improved reduce and softmax (#1819)
* Improve reduce perf and add contiguous impl
* Improve arg reduce and add contiguous impl
* Improve softmax kernel. 33%-39% higher throughput
* fmt
* Fixed all bugs. Improved code quality. Added tests.
* Stash for debugging
* Stash for debugging 2
* Fixing argmax bug and improve performance

Co-authored-by: Christopher Fleetwood <45471420+FL33TW00D@users.noreply.github.com>

* Fix test and add is_valid_simgroup_reduce_type trait
* Online softmax. Improved threadgroup reduce. Tidying up a bit. (A sketch of the online-softmax recurrence follows the commit summary below.)
* Remove redundant threadgroup_barrier from arg reduce
* Mostly tidying up. Some improvements
* Simplify indexed struct
* tidying
* Reuse operation operator instead of passing it in as a parameter
* Fix how operators are applied to indexed<vec<T,N>>
* Vectorized load. Scalar block reduce. Hitting max throughput for f32 reduce.
* Vectorized load for online softmax. Involves a reinterpret_cast of src which may be suboptimal.
* Metal as_type casting vec<bfloat, N> -> vec<float, N/2> for simd and fast math
* Use constant for input instead of const device. Fix strided reduce.
* Use contiguous reduce in tests
* Rename finalize -> to_scalar
* Support integer types max/min (switch with trait-inferred impl later)
* Was worried I was skipping work -> shuffling the 1D test cases
* Add build.rs to avoid metal kernel jit compile overhead
* Improve build. Extract utils
* Compile metal kernels for both macos and ios
* Fixed over xmas and then forgot about it
* Add calculate_reduce_threads util
* Remove old reduce.metal
* Improve f16/bf16 softmax precision by accumulating in f32
* Remove build.rs (for now)
* Move softmax bench to candle-nn
* Remove redundant thread calc util fn
* Use uint over ushort for indices etc
* Use fast exp in MDReduceOp
* Remove nested metal define for softmax
* Fix some clippy lint.

---------

Co-authored-by: Christopher Fleetwood <45471420+FL33TW00D@users.noreply.github.com>
Co-authored-by: Laurent <laurent.mazare@gmail.com>
1 parent 0af3e42 · commit 7c2449f · 12 changed files with 1,500 additions and 336 deletions.
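The headline kernel change is the online softmax. As a reference for what "online" means here, below is a minimal single-pass sketch of the recurrence in Rust. It is illustrative only: the commit's actual implementation is a Metal kernel with vectorized loads and simdgroup reductions, and the function name is made up.

// Minimal CPU sketch of the online-softmax recurrence (illustrative; the
// commit implements this as a Metal kernel). The running maximum and the
// running sum of exp(x - max) are maintained together in one pass, rescaling
// the sum whenever a new maximum appears, so max and sum do not need
// separate passes over the input.
fn online_softmax(xs: &[f32]) -> Vec<f32> {
    let mut max = f32::NEG_INFINITY;
    let mut sum = 0.0f32;
    for &x in xs {
        if x > max {
            // New maximum: rescale the sum accumulated under the old maximum.
            sum = sum * (max - x).exp() + 1.0;
            max = x;
        } else {
            sum += (x - max).exp();
        }
    }
    // One more pass only to write the normalized outputs.
    xs.iter().map(|&x| (x - max).exp() / sum).collect()
}

Subtracting the running maximum before exponentiating keeps the exponentials bounded, which is also what makes it safe to accumulate f16/bf16 inputs in f32, as the commit does to improve precision.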
@@ -0,0 +1,158 @@

use crate::benchmarks::{BenchDevice, BenchDeviceHandler};
use candle_core::{DType, Device, Tensor};
use criterion::{black_box, criterion_group, Criterion, Throughput};
use half::{bf16, f16};
use std::time::Instant;

// Reduce over the innermost dimension, keeping it in the output shape.
fn run_sum(a: &Tensor) {
    a.sum_keepdim(2).unwrap();
}

fn run_arg_min(a: &Tensor) {
    a.argmin_keepdim(2).unwrap();
}

fn criterion_benchmark(c: &mut Criterion) {
    let handler = BenchDeviceHandler::new().unwrap();
    let (lo, up) = (-1000.0f32, 1000.0f32);
    for device in handler.devices {
        // Contiguous layouts.
        run_reduce(c, &device, (lo, up), false);
        run_reduce(c, &device, (f16::from_f32(lo), f16::from_f32(up)), false);
        run_reduce(c, &device, (bf16::from_f32(lo), bf16::from_f32(up)), false);

        run_arg_reduce(c, &device, (lo, up), false);
        run_arg_reduce(c, &device, (f16::from_f32(lo), f16::from_f32(up)), false);
        run_arg_reduce(c, &device, (bf16::from_f32(lo), bf16::from_f32(up)), false);

        // Strided (non-contiguous) layouts.
        run_reduce(c, &device, (lo, up), true);
        run_reduce(c, &device, (f16::from_f32(lo), f16::from_f32(up)), true);
        run_reduce(c, &device, (bf16::from_f32(lo), bf16::from_f32(up)), true);

        run_arg_reduce(c, &device, (lo, up), true);
        run_arg_reduce(c, &device, (f16::from_f32(lo), f16::from_f32(up)), true);
        run_arg_reduce(c, &device, (bf16::from_f32(lo), bf16::from_f32(up)), true);
    }
}

fn run_reduce<T: candle_core::FloatDType>(
    c: &mut Criterion,
    device: &Device,
    (lo, up): (T, T),
    strided: bool,
) {
    let b = 1;
    let m = 1024;
    let k = 1024;

    // Transposing the random tensor produces the strided (non-contiguous) case.
    let a = if strided {
        Tensor::rand(lo, up, (b, m, k), device)
            .unwrap()
            .transpose(0, 2)
            .unwrap()
    } else {
        Tensor::rand(lo, up, (b, m, k), device).unwrap()
    };

    // Bytes read per iteration; reduce kernels are bandwidth-bound, so this is
    // the throughput measure reported to criterion.
    let bytes = b * m * k * T::DTYPE.size_in_bytes();

    let name = match T::DTYPE {
        DType::F32 => {
            if strided {
                "reduce_f32_strided"
            } else {
                "reduce_f32"
            }
        }
        DType::F16 => {
            if strided {
                "reduce_f16_strided"
            } else {
                "reduce_f16"
            }
        }
        DType::BF16 => {
            if strided {
                "reduce_bf16_strided"
            } else {
                "reduce_bf16"
            }
        }
        _ => "unknown",
    };

    let mut group = c.benchmark_group(device.bench_name(name));
    group.throughput(Throughput::Bytes(bytes as u64));
    group.bench_function("iter", move |b| {
        b.iter_custom(|iters| {
            let start = Instant::now();
            for _i in 0..iters {
                run_sum(black_box(&a));
            }
            // Sync once after the loop so queued kernels are included in the timing.
            device.sync().unwrap();
            start.elapsed()
        })
    });
    group.finish();
}

fn run_arg_reduce<T: candle_core::FloatDType>(
    c: &mut Criterion,
    device: &Device,
    (lo, up): (T, T),
    strided: bool,
) {
    let b = 1;
    let m = 1024;
    let k = 1024;

    let a = if strided {
        Tensor::rand(lo, up, (b, m, k), device)
            .unwrap()
            .transpose(0, 2)
            .unwrap()
    } else {
        Tensor::rand(lo, up, (b, m, k), device).unwrap()
    };

    let bytes = b * m * k * T::DTYPE.size_in_bytes();

    let name = match T::DTYPE {
        DType::F32 => {
            if strided {
                "arg_reduce_f32_strided"
            } else {
                "arg_reduce_f32"
            }
        }
        DType::F16 => {
            if strided {
                "arg_reduce_f16_strided"
            } else {
                "arg_reduce_f16"
            }
        }
        DType::BF16 => {
            if strided {
                "arg_reduce_bf16_strided"
            } else {
                "arg_reduce_bf16"
            }
        }
        _ => "unknown",
    };

    let mut group = c.benchmark_group(device.bench_name(name));
    group.throughput(Throughput::Bytes(bytes as u64));
    group.bench_function("iter", move |b| {
        b.iter_custom(|iters| {
            let start = Instant::now();
            for _i in 0..iters {
                run_arg_min(black_box(&a));
            }
            device.sync().unwrap();
            start.elapsed()
        })
    });
    group.finish();
}

criterion_group!(benches, criterion_benchmark);
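The arg reduce that run_arg_min exercises reduces over (value, index) pairs, which is what the "indexed struct" in the commit message refers to. Below is a rough Rust analogue of that combine step, with hypothetical names; the Metal kernel applies the same operator across simdgroup lanes and threadgroups rather than in a sequential loop.

// Hypothetical Rust analogue of the indexed argmin combine: carry
// (value, index) pairs and keep the pair with the smaller value,
// preferring the earlier index on ties.
#[derive(Clone, Copy)]
struct Indexed<T> {
    val: T,
    idx: u32,
}

fn argmin_indexed(xs: &[f32]) -> Indexed<f32> {
    xs.iter()
        .enumerate()
        .map(|(i, &v)| Indexed { val: v, idx: i as u32 })
        .reduce(|acc, x| if x.val < acc.val { x } else { acc })
        .unwrap()
}

Note that this module only declares its criterion_group!; like the crate's other benchmarks it presumably gets collected into a shared criterion_main! harness, so cargo bench should pick it up.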