Metal fill kernel #1501

Open
wants to merge 13 commits into base: main
3 changes: 2 additions & 1 deletion candle-core/benches/bench_main.rs
@@ -4,5 +4,6 @@ use criterion::criterion_main;
criterion_main!(
benchmarks::matmul::benches,
benchmarks::affine::benches,
benchmarks::where_cond::benches
benchmarks::fill::benches,
benchmarks::where_cond::benches,
);
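With the fill group registered, criterion's name filtering should pick up just the new benchmarks, e.g. via cargo bench -- fill (assuming the standard criterion harness this crate already uses).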
44 changes: 44 additions & 0 deletions candle-core/benches/benchmarks/fill.rs
@@ -0,0 +1,44 @@
use crate::benchmarks::{BenchDevice, BenchDeviceHandler};
use candle_core::{DType, Device, Tensor};
use criterion::{black_box, criterion_group, Criterion, Throughput};
use std::time::Instant;

fn run(shape: (usize, usize, usize), dtype: DType, device: &Device) {
Tensor::ones(shape, dtype, device).unwrap();
}

fn run_fill_benchmark(c: &mut Criterion, device: &Device, name: &str, dtype: DType) {
let b = 1;
let rows = 1024;
let columns = 1024;

let bytes = b * rows * columns * dtype.size_in_bytes();

let mut group = c.benchmark_group(device.bench_name(name));
group.throughput(Throughput::Bytes(bytes as u64));
group.bench_function("iter", move |bencher| {
bencher.iter_custom(|iters| {
let start = Instant::now();
for _i in 0..iters {
run(
black_box((b, rows, columns)),
black_box(dtype),
black_box(&device),
);
}
device.sync().unwrap();
start.elapsed()
})
});
group.finish();
}

fn criterion_benchmark(c: &mut Criterion) {
let handler = BenchDeviceHandler::new().unwrap();
for device in handler.devices {
run_fill_benchmark(c, &device, "fill_u8", DType::U8);
run_fill_benchmark(c, &device, "fill_f32", DType::F32);
}
}

criterion_group!(benches, criterion_benchmark);
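Throughput is measured in bytes written rather than FLOPs, since a fill performs no arithmetic. For the shapes above, the f32 case writes 1 × 1024 × 1024 × 4 bytes per iteration; a standalone sanity check (hypothetical snippet, not part of the PR):

fn main() {
    let (b, rows, columns) = (1usize, 1024, 1024);
    let bytes = b * rows * columns * 4; // DType::F32 is 4 bytes per element
    assert_eq!(bytes, 4 << 20); // 4 MiB written per fill
}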
1 change: 1 addition & 0 deletions candle-core/benches/benchmarks/mod.rs
@@ -1,4 +1,5 @@
pub(crate) mod affine;
pub(crate) mod fill;
pub(crate) mod matmul;
pub(crate) mod where_cond;

39 changes: 36 additions & 3 deletions candle-core/src/metal_backend.rs
@@ -4,6 +4,7 @@ use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT};
use crate::{CpuStorage, DType, Layout, Result, Shape};
use candle_metal_kernels;
use candle_metal_kernels::Kernels;
use half::{bf16, f16};
use metal;
use metal::{Buffer, CommandBuffer, CommandQueue, MTLResourceOptions, NSUInteger};
use std::collections::HashMap;
@@ -1586,9 +1587,41 @@ impl BackendDevice for MetalDevice {
}

fn ones_impl(&self, shape: &Shape, dtype: DType) -> Result<Self::Storage> {
// TODO Is there a faster way ?
let cpu_storage = crate::cpu_backend::CpuDevice.ones_impl(shape, dtype)?;
self.storage_from_cpu_storage(&cpu_storage)
let buffer = self.new_buffer(shape.elem_count(), dtype, "ones")?;
let command_buffer = self.command_buffer()?;
command_buffer.set_label("ones");

macro_rules! fill {
($value:expr) => {
candle_metal_kernels::call_fill(
&self.device,
&command_buffer,
&self.kernels,
shape.elem_count(),
&buffer,
$value,
)
.map_err(MetalError::from)?
};
}
match dtype {
DType::U8 => candle_metal_kernels::call_fill_u8(
&command_buffer,
shape.elem_count(),
&buffer,
1u8,
)
.map_err(MetalError::from)?,
DType::U32 => fill!(1u32),
DType::I64 => fill!(1i64),
DType::BF16 => fill!(bf16::ONE),
DType::F16 => fill!(f16::ONE),
DType::F32 => fill!(1f32),
DType::F64 => {
return Err(MetalError::Message("Metal doesn't support f64".to_string()).into())
}
}
Ok(MetalStorage::new(buffer, self.clone(), dtype))
}

fn storage_from_cpu_storage(&self, storage: &CpuStorage) -> Result<Self::Storage> {
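For reference, a fill!(1u32) invocation above expands to roughly the following (sketch derived from the macro body):

candle_metal_kernels::call_fill(
    &self.device,
    &command_buffer,
    &self.kernels,
    shape.elem_count(),
    &buffer,
    1u32,
)
.map_err(MetalError::from)?

The DType::U8 arm bypasses the compute kernel entirely and uses call_fill_u8, which encodes a blit fill_buffer command instead.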
8 changes: 2 additions & 6 deletions candle-metal-kernels/Cargo.toml
@@ -9,17 +9,13 @@ keywords = ["blas", "tensor", "machine-learning"]
categories = ["science"]
license = "MIT OR Apache-2.0"


[dependencies]
metal = { version = "0.27.0", features = ["mps"] }
once_cell = "1.18.0"
thiserror = "1"
tracing = "0.1.37"
half = { version = "2.3.1", features = ["num-traits", "use-intrinsics", "rand_distr"] }
num-traits = "0.2.17"

[dev-dependencies]
half = { version = "2.3.1", features = [
"num-traits",
"use-intrinsics",
"rand_distr",
] }
rand = "0.8.5"
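Note that half moves from dev-dependencies to a regular dependency: the fill path now uses bf16/f16 at runtime (see the half::{bf16, f16} imports added to lib.rs and metal_backend.rs).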
34 changes: 34 additions & 0 deletions candle-metal-kernels/src/fill.metal
@@ -0,0 +1,34 @@
#include <metal_stdlib>
using namespace metal;

template<typename T>
void fill(
device T *buffer [[buffer(0)]],
constant T &value,
constant size_t &numel,
uint gid [[thread_position_in_grid]]
) {
if (gid >= numel) return;
buffer[gid] = value;
}

#define FILL_OP(T, FN_NAME) \
kernel void FN_NAME( \
device T *buffer [[buffer(0)]], \
constant T &value, \
constant size_t &numel, \
uint gid [[thread_position_in_grid]] \
) { fill<T>(buffer, value, numel, gid); } \

FILL_OP(uint8_t, fill_u8)
FILL_OP(uint32_t, fill_u32)
FILL_OP(half, fill_f16)
FILL_OP(float, fill_f32)

#if __METAL_VERSION__ >= 220
FILL_OP(int64_t, fill_i64)
#endif

#if defined(__HAVE_BFLOAT__)
FILL_OP(bfloat, fill_bf16)
#endif
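The kernel writes one element per thread; the gid >= numel guard covers the final, partially filled threadgroup when numel is not a multiple of the threadgroup size. A sketch of the host-side grid sizing this implies (assuming linear_split does the usual ceil division; grid_for is a hypothetical helper, not part of this PR):

fn grid_for(elem_count: usize, max_threads_per_group: usize) -> (u64, u64) {
    let threads_per_group = max_threads_per_group.min(elem_count);
    // Ceil division: enough groups to cover every element.
    let groups = (elem_count + threads_per_group - 1) / threads_per_group;
    (groups as u64, threads_per_group as u64)
}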
74 changes: 71 additions & 3 deletions candle-metal-kernels/src/lib.rs
@@ -1,3 +1,4 @@
use half::{bf16, f16};
use metal::{
Buffer, CommandBufferRef, CompileOptions, ComputeCommandEncoderRef, ComputePipelineState,
Device, Function, FunctionConstantValues, Library, MTLDataType, MTLSize, NSUInteger,
@@ -12,6 +13,7 @@ const UNARY: &str = include_str!("unary.metal");
const BINARY: &str = include_str!("binary.metal");
const TERNARY: &str = include_str!("ternary.metal");
const CAST: &str = include_str!("cast.metal");
const FILL: &str = include_str!("fill.metal");
const REDUCE: &str = include_str!("reduce.metal");
const CONV: &str = include_str!("conv.metal");
const MFA: &[u8] = include_bytes!("libMetalFlashAttention.metallib");
@@ -46,7 +48,7 @@ fn set_param<P: EncoderParam>(encoder: &ComputeCommandEncoderRef, position: u64,
/// Helper functions to create the various objects on the compute command encoder
/// on a single line.
/// Prevents getting wrong some arguments number and mixing length and size in bytes.
trait EncoderParam {
pub trait EncoderParam {
fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self);
}
macro_rules! primitive {
@@ -63,9 +65,12 @@ macro_rules! primitive {
};
}
primitive!(usize);
primitive!(i64);
primitive!(i32);
primitive!(u8);
primitive!(u32);
primitive!(i32);
primitive!(i64);
primitive!(f16);
primitive!(bf16);
primitive!(f32);
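// f16 and bf16 come from the `half` crate; registering them here lets fill
// values of those types be passed to the kernel like any other primitive.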

impl<T> EncoderParam for &[T] {
@@ -120,6 +125,7 @@ pub enum Source {
Reduce,
Mfa,
Conv,
Fill,
Quantized,
}

@@ -188,6 +194,8 @@ pub mod binary {

#[derive(thiserror::Error, Debug)]
pub enum MetalKernelError {
#[error("Invalid usage of kernel: {0}")]
InvalidUsage(String),
#[error("Could not lock kernel map: {0}")]
LockError(String),
#[error("Error while loading library: {0}")]
@@ -240,6 +248,7 @@ impl Kernels {
Source::Indexing => INDEXING,
Source::Cast => CAST,
Source::Reduce => REDUCE,
Source::Fill => FILL,
Source::Conv => CONV,
Source::Quantized => QUANTIZED,
Source::Mfa => panic!("Invalid lib"),
@@ -1697,9 +1706,68 @@ pub fn call_quantized_matmul_t(
Ok(())
}

#[inline(always)]
fn divide(m: usize, b: usize) -> NSUInteger {
((m + b - 1) / b) as NSUInteger
}

pub fn call_fill<T: FillOp>(
device: &Device,
command_buffer: &CommandBufferRef,
kernels: &Kernels,
elem_count: usize,
buffer: &Buffer,
value: T,
) -> Result<(), MetalKernelError> {
let pipeline = kernels.load_pipeline(device, Source::Fill, T::FILL_KERNEL)?;
let encoder = command_buffer.new_compute_command_encoder();
encoder.set_compute_pipeline_state(&pipeline);

set_params!(encoder, (buffer, value, elem_count));

let (thread_group_count, thread_group_size) = linear_split(&pipeline, elem_count);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.use_resource(buffer, metal::MTLResourceUsage::Write);
encoder.end_encoding();

Ok(())
}

pub fn call_fill_u8(
command_buffer: &CommandBufferRef,
elem_count: usize,
buffer: &Buffer,
value: u8,
) -> Result<(), MetalKernelError> {
let blit = command_buffer.new_blit_command_encoder();
blit.fill_buffer(
buffer,
metal::NSRange {
location: 0,
length: elem_count as NSUInteger,
},
value,
);
blit.end_encoding();

Ok(())
}
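// Only u8 can use the blit encoder's fill_buffer, which repeats a single byte
// pattern; wider types (u32, i64, f16, bf16, f32) go through the fill.metal
// compute kernel via call_fill above.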

pub trait FillOp: EncoderParam {
const FILL_KERNEL: &'static str;
}

macro_rules! impl_call_fill {
($($t:ty),*) => {
$(
impl FillOp for $t {
const FILL_KERNEL: &'static str = concat!("fill_", stringify!($t));
}
)*
};
}
impl_call_fill!(u32, i64, f16, bf16, f32);

#[cfg(test)]
mod tests;
46 changes: 44 additions & 2 deletions candle-metal-kernels/src/tests.rs
@@ -11,7 +11,7 @@ fn read_to_vec<T: Clone>(buffer: &Buffer, n: usize) -> Vec<T> {

fn new_buffer<T>(device: &Device, data: &[T]) -> Buffer {
let options = MTLResourceOptions::StorageModeManaged;
let ptr = data.as_ptr() as *const core::ffi::c_void;
let ptr = data.as_ptr() as *const c_void;
let size = (data.len() * std::mem::size_of::<T>()) as u64;
device.new_buffer_with_data(ptr, size, options)
}
@@ -713,7 +713,6 @@ fn softmax() {
}
let results = run_softmax(&v, last_dim, "softmax_f32");
let results = approx(results, 4);
println!("{results:?}");
assert_eq!(
results.iter().map(|&s| s.round() as usize).sum::<usize>(),
n
@@ -927,3 +926,46 @@ fn gemm() {
vec![56.0, 59.0, 62.0, 65.0, 200.0, 212.0, 224.0, 236.0]
);
}

fn run_fill<T: FillOp + Clone>(elem_count: usize, value: T) -> Vec<T> {
    let device = device();
    let fence = device.new_fence();
    let kernels = Kernels::new(fence);
    let command_queue = device.new_command_queue();
    let command_buffer = command_queue.new_command_buffer();
    // Allocate elem_count zero-initialized elements of T, sized in raw bytes
    // so the buffer is correct for 2-, 4- and 8-byte types alike.
    let buffer = new_buffer(&device, &vec![0u8; elem_count * std::mem::size_of::<T>()]);
    call_fill(
        &device,
        command_buffer,
        &kernels,
        elem_count,
        &buffer,
        value,
    )
    .unwrap();
    command_buffer.commit();
    command_buffer.wait_until_completed();

    read_to_vec(&buffer, elem_count)
}

#[test]
fn fill() {
    fn assert_fill<T: FillOp + Copy + std::fmt::Debug + PartialEq>(value: T) {
        // Cover sizes 1, 8, 64 and 512 to exercise partially filled threadgroups.
        for i in 0..4 {
            let n = 8usize.pow(i);
            assert_eq!(run_fill(n, value), vec![value; n]);
        }
    }
    // u8 goes through the blit-encoder path (call_fill_u8) rather than the
    // compute kernel, so it is not covered by the FillOp-based helper above.
    assert_fill(456u32);
    assert_fill(789i64);
    assert_fill(f16::from_f32(1.23));
    assert_fill(bf16::from_f32(4.56));
    assert_fill(7.89f32);
}
1 change: 0 additions & 1 deletion candle-metal-kernels/src/unary.metal
@@ -1,6 +1,5 @@
#include <metal_stdlib>
#include <metal_math>
#
using namespace metal;

METAL_FUNC uint get_strided_index(