Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Metal fill kernel #1501

Open
wants to merge 13 commits into
base: main
Choose a base branch
from
1 change: 1 addition & 0 deletions candle-core/benches/bench_main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ mod benchmarks;
use criterion::criterion_main;
criterion_main!(
benchmarks::affine::benches,
benchmarks::fill::benches,
benchmarks::matmul::benches,
benchmarks::random::benches,
benchmarks::where_cond::benches
Expand Down
44 changes: 44 additions & 0 deletions candle-core/benches/benchmarks/fill.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
use crate::benchmarks::{BenchDevice, BenchDeviceHandler};
use candle_core::{DType, Device, Tensor};
use criterion::{black_box, criterion_group, Criterion, Throughput};
use std::time::Instant;

fn run(shape: (usize, usize, usize), dtype: DType, device: &Device) {
Tensor::ones(shape, dtype, device).unwrap();
}

fn run_fill_benchmark(c: &mut Criterion, device: &Device, name: &str, dtype: DType) {
let b = 1;
let rows = 1024;
let columns = 1024;

let flops = b * rows * columns * dtype.size_in_bytes();

let mut group = c.benchmark_group(device.bench_name(name));
group.throughput(Throughput::Bytes(flops as u64));
group.bench_function("iter", move |bencher| {
bencher.iter_custom(|iters| {
let start = Instant::now();
for _i in 0..iters {
run(
black_box((b, rows, columns)),
black_box(dtype),
black_box(&device),
);
}
device.sync().unwrap();
start.elapsed()
})
});
group.finish();
}

fn criterion_benchmark(c: &mut Criterion) {
let handler = BenchDeviceHandler::new().unwrap();
for device in handler.devices {
run_fill_benchmark(c, &device, "fill_u8", DType::U8);
run_fill_benchmark(c, &device, "fill_f32", DType::F32);
}
}

criterion_group!(benches, criterion_benchmark);
1 change: 1 addition & 0 deletions candle-core/benches/benchmarks/mod.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
pub(crate) mod affine;
pub(crate) mod fill;
pub(crate) mod matmul;
pub(crate) mod random;
pub(crate) mod where_cond;
Expand Down
39 changes: 36 additions & 3 deletions candle-core/src/metal_backend.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT};
use crate::{CpuStorage, DType, Layout, Result, Shape};
use candle_metal_kernels;
use candle_metal_kernels::Kernels;
use half::{bf16, f16};
use metal;
use metal::{Buffer, CommandBuffer, CommandQueue, MTLResourceOptions, NSUInteger};
use std::collections::HashMap;
Expand Down Expand Up @@ -1591,9 +1592,41 @@ impl BackendDevice for MetalDevice {
}

fn ones_impl(&self, shape: &Shape, dtype: DType) -> Result<Self::Storage> {
// TODO Is there a faster way ?
let cpu_storage = crate::cpu_backend::CpuDevice.ones_impl(shape, dtype)?;
self.storage_from_cpu_storage(&cpu_storage)
let buffer = self.new_buffer(shape.elem_count(), dtype, "ones")?;
let command_buffer = self.command_buffer()?;
command_buffer.set_label("ones");

macro_rules! fill {
($value:expr) => {
candle_metal_kernels::call_fill(
&self.device,
&command_buffer,
&self.kernels,
shape.elem_count(),
&buffer,
$value,
)
.map_err(MetalError::from)?
};
}
match dtype {
DType::U8 => candle_metal_kernels::call_fill_u8(
&command_buffer,
shape.elem_count(),
&buffer,
1u8,
)
.map_err(MetalError::from)?,
DType::U32 => fill!(1u32),
DType::I64 => fill!(1i64),
DType::BF16 => fill!(bf16::ONE),
DType::F16 => fill!(f16::ONE),
DType::F32 => fill!(1f32),
DType::F64 => {
return Err(MetalError::Message("Metal doesn't support double".to_string()).into())
}
}
Ok(MetalStorage::new(buffer, self.clone(), dtype))
}

fn storage_from_cpu_storage(&self, storage: &CpuStorage) -> Result<Self::Storage> {
Expand Down
8 changes: 2 additions & 6 deletions candle-metal-kernels/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,17 +9,13 @@ keywords = ["blas", "tensor", "machine-learning"]
categories = ["science"]
license = "MIT OR Apache-2.0"


[dependencies]
metal = { version = "0.27.0", features = ["mps"] }
once_cell = "1.18.0"
thiserror = "1"
tracing = "0.1.37"
half = { version = "2.3.1", features = ["num-traits", "use-intrinsics", "rand_distr"] }
num-traits = "0.2.17"

[dev-dependencies]
half = { version = "2.3.1", features = [
"num-traits",
"use-intrinsics",
"rand_distr",
] }
rand = "0.8.5"
34 changes: 34 additions & 0 deletions candle-metal-kernels/src/fill.metal
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
#include <metal_stdlib>
using namespace metal;

template<typename T>
void fill(
device T *buffer [[buffer(0)]],
constant T &value,
constant size_t &numel,
uint gid [[thread_position_in_grid]]
) {
if (gid >= numel) return;
buffer[gid] = value;
}

#define FILL_OP(T, FN_NAME) \
kernel void FN_NAME( \
device T *buffer [[buffer(0)]], \
constant T &value, \
constant size_t &numel, \
uint gid [[thread_position_in_grid]] \
) { fill<T>(buffer, value, numel, gid); } \

FILL_OP(uint8_t, fill_u8)
FILL_OP(uint32_t, fill_u32)
FILL_OP(half, fill_f16)
FILL_OP(float, fill_f32)

#if __METAL_VERSION__ >= 220
FILL_OP(int64_t, fill_i64)
#endif

#if defined(__HAVE_BFLOAT__)
FILL_OP(bfloat, fill_bf16)
#endif
114 changes: 96 additions & 18 deletions candle-metal-kernels/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use half::{bf16, f16};
use metal::{
Buffer, CommandBufferRef, CompileOptions, ComputeCommandEncoderRef, ComputePipelineState,
Device, Function, FunctionConstantValues, Library, MTLDataType, MTLSize, NSUInteger,
Expand All @@ -12,6 +13,7 @@ const UNARY: &str = include_str!("unary.metal");
const BINARY: &str = include_str!("binary.metal");
const TERNARY: &str = include_str!("ternary.metal");
const CAST: &str = include_str!("cast.metal");
const FILL: &str = include_str!("fill.metal");
const CONV: &str = include_str!("conv.metal");
const REDUCE: &str = include_str!("reduce.metal");
const RANDOM: &str = include_str!("random.metal");
Expand Down Expand Up @@ -47,29 +49,26 @@ fn set_param<P: EncoderParam>(encoder: &ComputeCommandEncoderRef, position: u64,
/// Helper functions to create the various objects on the compute command encoder
/// on a single line.
/// Prevents getting wrong some arguments number and mixing length and size in bytes.
trait EncoderParam {
pub trait EncoderParam: private::Sealed {
fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self);
}
macro_rules! primitive {
($type:ty) => {
impl EncoderParam for $type {
fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self) {
encoder.set_bytes(
position,
core::mem::size_of::<$type>() as u64,
&data as *const $type as *const c_void,
);

macro_rules! primitives {
($($type:ty),+) => {
$(
impl EncoderParam for $type {
fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self) {
encoder.set_bytes(
position,
core::mem::size_of::<$type>() as u64,
&data as *const $type as *const c_void,
);
}
}
}
)+
};
}
primitive!(bool);
primitive!(usize);
primitive!(i32);
primitive!(i64);
primitive!(u32);
primitive!(u64);
primitive!(f32);
primitives!(bool, usize, u8, u32, u64, i32, i64, f16, bf16, f32);

impl<T> EncoderParam for &[T] {
fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self) {
Expand Down Expand Up @@ -112,6 +111,22 @@ macro_rules! set_params {
);
}

// Seal for EncoderParam so that only the types we want can implement it
mod private {
use super::*;

pub trait Sealed {}

macro_rules! sealed {
($($type:ty),+) => {
$(impl Sealed for $type {})+
};
}
sealed!(usize, u8, u32, u64, i32, i64, f16, bf16, f32, bool);
sealed!(&Buffer, (&Buffer, usize), &mut Buffer, (&mut Buffer, usize));
impl<T> Sealed for &[T] {}
}

#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Source {
Affine,
Expand All @@ -123,6 +138,7 @@ pub enum Source {
Reduce,
Mfa,
Conv,
Fill,
Random,
Quantized,
}
Expand Down Expand Up @@ -192,6 +208,8 @@ pub mod binary {

#[derive(thiserror::Error, Debug)]
pub enum MetalKernelError {
#[error("Invalid usage of kernel: {0}")]
InvalidUsage(String),
#[error("Could not lock kernel map: {0}")]
LockError(String),
#[error("Error while loading library: {0}")]
Expand Down Expand Up @@ -244,6 +262,7 @@ impl Kernels {
Source::Indexing => INDEXING,
Source::Cast => CAST,
Source::Reduce => REDUCE,
Source::Fill => FILL,
Source::Conv => CONV,
Source::Random => RANDOM,
Source::Quantized => QUANTIZED,
Expand Down Expand Up @@ -1769,9 +1788,68 @@ pub fn call_quantized_matmul_t(
Ok(())
}

#[inline(always)]
fn divide(m: usize, b: usize) -> NSUInteger {
((m + b - 1) / b) as NSUInteger
}

pub fn call_fill<T: FillOp>(
device: &Device,
command_buffer: &CommandBufferRef,
kernels: &Kernels,
elem_count: usize,
buffer: &Buffer,
value: T,
) -> Result<(), MetalKernelError> {
let pipeline = kernels.load_pipeline(device, Source::Fill, T::FILL_KERNEL)?;
let encoder = command_buffer.new_compute_command_encoder();
encoder.set_compute_pipeline_state(&pipeline);
encoder.set_threadgroup_memory_length(0, elem_count as NSUInteger);

set_params!(encoder, (buffer, value, elem_count));

let (thread_group_count, thread_group_size) = linear_split(&pipeline, elem_count);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.use_resource(buffer, metal::MTLResourceUsage::Write);
encoder.end_encoding();

Ok(())
}

pub fn call_fill_u8(
command_buffer: &CommandBufferRef,
elem_count: usize,
buffer: &Buffer,
value: u8,
) -> Result<(), MetalKernelError> {
let blit = command_buffer.new_blit_command_encoder();
blit.fill_buffer(
buffer,
metal::NSRange {
location: 0,
length: elem_count as NSUInteger,
},
value,
);
blit.end_encoding();

Ok(())
}

pub trait FillOp: EncoderParam {
const FILL_KERNEL: &'static str;
}

macro_rules ! impl_call_fill {
($($t:ty),*) => {
$(
impl FillOp for $t {
const FILL_KERNEL: &'static str = concat!("fill_", stringify!($t));
}
)*
};
}
impl_call_fill!(u8, u32, i64, f16, bf16, f32);

#[cfg(test)]
mod tests;
36 changes: 36 additions & 0 deletions candle-metal-kernels/src/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -927,6 +927,42 @@ fn gemm() {
);
}

fn run_fill<T: FillOp + Clone>(elem_count: usize, value: T) -> Vec<T> {
let device = device();
let kernels = Kernels::new();
let command_queue = device.new_command_queue();
let command_buffer = command_queue.new_command_buffer();
let buffer = new_buffer(&device, &vec![0.0f32; elem_count]);
call_fill(
&device,
command_buffer,
&kernels,
elem_count,
&buffer,
value,
)
.unwrap();
command_buffer.commit();
command_buffer.wait_until_completed();

read_to_vec(&buffer, elem_count)
}

#[test]
fn fill() {
fn assert_fill<T: FillOp + Copy + std::fmt::Debug + PartialEq>(value: T) {
for i in 0..4 {
assert_eq!(run_fill(8 ^ i, value), vec![value; 8 ^ i]);
}
}
assert_fill(123u8);
assert_fill(456u32);
assert_fill(789i64);
assert_fill(f16::from_f32(1.23));
assert_fill(bf16::from_f32(4.56));
assert_fill(7.89f32);
}

fn run_random<T: Clone>(name: &'static str, seed: u32, length: usize, a: f32, b: f32) -> Vec<T> {
let device = device();
let kernels = Kernels::new();
Expand Down
1 change: 0 additions & 1 deletion candle-metal-kernels/src/unary.metal
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
#include <metal_stdlib>
#include <metal_math>
#
using namespace metal;

METAL_FUNC uint get_strided_index(
Expand Down
Loading