Skip to content

Commit

Permalink
feat(bf16): add cast support + tests for cast + bin ops (#1524)
Browse files Browse the repository at this point in the history
  • Loading branch information
kyle-mccarthy authored Jan 11, 2024
1 parent 9f0c99f commit 402349d
Show file tree
Hide file tree
Showing 6 changed files with 243 additions and 18 deletions.
1 change: 0 additions & 1 deletion candle-core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -48,4 +48,3 @@ metal = ["dep:metal", "dep:candle-metal-kernels"]
[[bench]]
name = "bench_main"
harness = false

54 changes: 52 additions & 2 deletions candle-core/src/metal_backend.rs
Original file line number Diff line number Diff line change
Expand Up @@ -590,14 +590,26 @@ impl BackendStorage for MetalStorage {
(DType::U32, DType::F32) => "cast_u32_f32",
(DType::U32, DType::U8) => "cast_u32_u8",
(DType::U32, DType::I64) => "cast_u32_i64",
(DType::U32, DType::BF16) => "cast_u32_bf16",

(DType::U8, DType::U32) => "cast_u8_u32",
(DType::U8, DType::F32) => "cast_u8_f32",
(DType::U8, DType::I64) => "cast_u8_i64",
(DType::U8, DType::BF16) => "cast_u8_bf16",

(DType::F32, DType::F16) => "cast_f32_f16",
(DType::F16, DType::F32) => "cast_f16_f32",
(DType::I64, DType::F32) => "cast_i64_f32",
(DType::F32, DType::BF16) => "cast_f32_bf16",

(DType::I64, DType::F32) => "cast_i64_f32",

(DType::F16, DType::BF16) => "cast_f16_bf16",
(DType::F16, DType::F32) => "cast_f16_f32",

(DType::BF16, DType::U8) => "cast_bf16_u8",
(DType::BF16, DType::U32) => "cast_bf16_u32",
(DType::BF16, DType::F16) => "cast_bf16_f16",
(DType::BF16, DType::F32) => "cast_bf16_f32",

(left, right) => {
crate::bail!("Metal contiguous to_dtype {left:?} {right:?} not implemented")
}
Expand Down Expand Up @@ -1131,8 +1143,12 @@ impl BackendStorage for MetalStorage {
let device = self.device();
let buffer = device.new_buffer(dst_el, dtype, "index_select")?;
let name = match (ids.dtype, self.dtype) {
(DType::U8, DType::BF16) => "is_u8_bf16",

(DType::U32, DType::F32) => "is_u32_f32",
(DType::U32, DType::F16) => "is_u32_f16",
(DType::U32, DType::BF16) => "is_u32_bf16",

(left, right) => {
crate::bail!("Metal contiguous index_select {left:?} {right:?} not implemented")
}
Expand Down Expand Up @@ -1322,6 +1338,7 @@ impl MetalStorage {
("lt", DType::F32) => (contiguous::lt::FLOAT, DType::U8),
("ge", DType::F32) => (contiguous::ge::FLOAT, DType::U8),
("gt", DType::F32) => (contiguous::gt::FLOAT, DType::U8),

("add", DType::F16) => (contiguous::add::HALF, self.dtype),
("sub", DType::F16) => (contiguous::sub::HALF, self.dtype),
("mul", DType::F16) => (contiguous::mul::HALF, self.dtype),
Expand All @@ -1332,6 +1349,18 @@ impl MetalStorage {
("lt", DType::F16) => (contiguous::lt::HALF, DType::U8),
("ge", DType::F16) => (contiguous::ge::HALF, DType::U8),
("gt", DType::F16) => (contiguous::gt::HALF, DType::U8),

("add", DType::BF16) => (contiguous::add::BFLOAT, self.dtype),
("sub", DType::BF16) => (contiguous::sub::BFLOAT, self.dtype),
("mul", DType::BF16) => (contiguous::mul::BFLOAT, self.dtype),
("div", DType::BF16) => (contiguous::div::BFLOAT, self.dtype),
("eq", DType::BF16) => (contiguous::eq::BFLOAT, DType::U8),
("ne", DType::BF16) => (contiguous::ne::BFLOAT, DType::U8),
("le", DType::BF16) => (contiguous::le::BFLOAT, DType::U8),
("lt", DType::BF16) => (contiguous::lt::BFLOAT, DType::U8),
("ge", DType::BF16) => (contiguous::ge::BFLOAT, DType::U8),
("gt", DType::BF16) => (contiguous::gt::BFLOAT, DType::U8),

("add", DType::I64) => (contiguous::add::I64, self.dtype),
("sub", DType::I64) => (contiguous::sub::I64, self.dtype),
("mul", DType::I64) => (contiguous::mul::I64, self.dtype),
Expand All @@ -1342,6 +1371,7 @@ impl MetalStorage {
("lt", DType::I64) => (contiguous::lt::I64, DType::U8),
("ge", DType::I64) => (contiguous::ge::I64, DType::U8),
("gt", DType::I64) => (contiguous::gt::I64, DType::U8),

("add", DType::U32) => (contiguous::add::U32, self.dtype),
("sub", DType::U32) => (contiguous::sub::U32, self.dtype),
("mul", DType::U32) => (contiguous::mul::U32, self.dtype),
Expand All @@ -1352,6 +1382,7 @@ impl MetalStorage {
("lt", DType::U32) => (contiguous::lt::U32, DType::U8),
("ge", DType::U32) => (contiguous::ge::U32, DType::U8),
("gt", DType::U32) => (contiguous::gt::U32, DType::U8),

("add", DType::U8) => (contiguous::add::U8, self.dtype),
("sub", DType::U8) => (contiguous::sub::U8, self.dtype),
("mul", DType::U8) => (contiguous::mul::U8, self.dtype),
Expand All @@ -1362,6 +1393,7 @@ impl MetalStorage {
("lt", DType::U8) => (contiguous::lt::U8, DType::U8),
("ge", DType::U8) => (contiguous::ge::U8, DType::U8),
("gt", DType::U8) => (contiguous::gt::U8, DType::U8),

(name, dtype) => {
crate::bail!("Metal contiguous binary {name} {dtype:?} not implemented")
}
Expand Down Expand Up @@ -1395,6 +1427,7 @@ impl MetalStorage {
("lt", DType::F32) => (strided::lt::FLOAT, DType::U8),
("ge", DType::F32) => (strided::ge::FLOAT, DType::U8),
("gt", DType::F32) => (strided::gt::FLOAT, DType::U8),

("badd", DType::F16) => (strided::add::HALF, self.dtype),
("bsub", DType::F16) => (strided::sub::HALF, self.dtype),
("bmul", DType::F16) => (strided::mul::HALF, self.dtype),
Expand All @@ -1407,6 +1440,20 @@ impl MetalStorage {
("lt", DType::F16) => (strided::lt::HALF, DType::U8),
("ge", DType::F16) => (strided::ge::HALF, DType::U8),
("gt", DType::F16) => (strided::gt::HALF, DType::U8),

("badd", DType::BF16) => (strided::add::BFLOAT, self.dtype),
("bsub", DType::BF16) => (strided::sub::BFLOAT, self.dtype),
("bmul", DType::BF16) => (strided::mul::BFLOAT, self.dtype),
("bdiv", DType::BF16) => (strided::div::BFLOAT, self.dtype),
("bminimum", DType::BF16) => (strided::min::BFLOAT, self.dtype),
("bmaximum", DType::BF16) => (strided::max::BFLOAT, self.dtype),
("eq", DType::BF16) => (strided::eq::BFLOAT, DType::U8),
("ne", DType::BF16) => (strided::ne::BFLOAT, DType::U8),
("le", DType::BF16) => (strided::le::BFLOAT, DType::U8),
("lt", DType::BF16) => (strided::lt::BFLOAT, DType::U8),
("ge", DType::BF16) => (strided::ge::BFLOAT, DType::U8),
("gt", DType::BF16) => (strided::gt::BFLOAT, DType::U8),

("badd", DType::I64) => (strided::add::I64, self.dtype),
("bsub", DType::I64) => (strided::sub::I64, self.dtype),
("bmul", DType::I64) => (strided::mul::I64, self.dtype),
Expand All @@ -1419,6 +1466,7 @@ impl MetalStorage {
("lt", DType::I64) => (strided::lt::I64, DType::U8),
("ge", DType::I64) => (strided::ge::I64, DType::U8),
("gt", DType::I64) => (strided::gt::I64, DType::U8),

("badd", DType::U32) => (strided::add::U32, self.dtype),
("bsub", DType::U32) => (strided::sub::U32, self.dtype),
("bmul", DType::U32) => (strided::mul::U32, self.dtype),
Expand All @@ -1431,6 +1479,7 @@ impl MetalStorage {
("lt", DType::U32) => (strided::lt::U32, DType::U8),
("ge", DType::U32) => (strided::ge::U32, DType::U8),
("gt", DType::U32) => (strided::gt::U32, DType::U8),

("badd", DType::U8) => (strided::add::U8, self.dtype),
("bsub", DType::U8) => (strided::sub::U8, self.dtype),
("bmul", DType::U8) => (strided::mul::U8, self.dtype),
Expand All @@ -1443,6 +1492,7 @@ impl MetalStorage {
("lt", DType::U8) => (strided::lt::U8, DType::U8),
("ge", DType::U8) => (strided::ge::U8, DType::U8),
("gt", DType::U8) => (strided::gt::U8, DType::U8),

(name, dtype) => {
crate::bail!("Metal strided binary {name} {dtype:?} not implemented")
}
Expand Down
9 changes: 7 additions & 2 deletions candle-metal-kernels/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,17 @@ keywords = ["blas", "tensor", "machine-learning"]
categories = ["science"]
license = "MIT OR Apache-2.0"


[dependencies]
metal = { version = "0.27.0", features = ["mps"]}
metal = { version = "0.27.0", features = ["mps"] }
once_cell = "1.18.0"
thiserror = "1"
tracing = "0.1.37"

[dev-dependencies]
half = { version = "2.3.1", features = ["num-traits", "use-intrinsics", "rand_distr"] }
half = { version = "2.3.1", features = [
"num-traits",
"use-intrinsics",
"rand_distr",
] }
rand = "0.8.5"
40 changes: 38 additions & 2 deletions candle-metal-kernels/src/cast.metal
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ kernel void FN_NAME( \
if (tid >= dim) { \
return; \
} \
output[tid] = RIGHT_TYPENAME(input[tid]); \
output[tid] = static_cast<RIGHT_TYPENAME>(input[tid]); \
} \
kernel void FN_NAME_STRIDED( \
constant size_t &dim, \
Expand All @@ -42,7 +42,34 @@ kernel void FN_NAME_STRIDED( \
if (tid >= dim) { \
return; \
} \
output[tid] = RIGHT_TYPENAME(input[get_strided_index(tid, num_dims, dims, strides)]); \
output[tid] = static_cast<RIGHT_TYPENAME>(input[get_strided_index(tid, num_dims, dims, strides)]); \
} \

#define CAST_THROUGH(FN_NAME, FN_NAME_STRIDED, LEFT_TYPENAME, RIGHT_TYPENAME, IR_TYPENAME) \
kernel void FN_NAME( \
constant size_t &dim, \
device const LEFT_TYPENAME *input, \
device RIGHT_TYPENAME *output, \
uint tid [[ thread_position_in_grid ]] \
) { \
if (tid >= dim) { \
return; \
} \
output[tid] = static_cast<RIGHT_TYPENAME>(static_cast<IR_TYPENAME>(input[tid])); \
} \
kernel void FN_NAME_STRIDED( \
constant size_t &dim, \
constant size_t &num_dims, \
constant size_t *dims, \
constant size_t *strides, \
device const LEFT_TYPENAME *input, \
device RIGHT_TYPENAME *output, \
uint tid [[ thread_position_in_grid ]] \
) { \
if (tid >= dim) { \
return; \
} \
output[tid] = static_cast<RIGHT_TYPENAME>(static_cast<IR_TYPENAME>(input[get_strided_index(tid, num_dims, dims, strides)])); \
} \

CAST(cast_u32_f32, cast_u32_f32_strided, uint32_t, float)
Expand All @@ -59,6 +86,15 @@ CAST(cast_i64_f32, cast_i64_f32_strided, int64_t, float)
#endif

#if defined(__HAVE_BFLOAT__)
#if __METAL_VERSION__ >= 310
CAST(cast_bf16_u32, cast_bf16_u32_strided, bfloat, uint32_t)
CAST(cast_bf16_f32, cast_bf16_f32_strided, bfloat, float)

CAST(cast_u8_bf16, cast_u8_bf16_strided, uint8_t, bfloat)
CAST(cast_u32_bf16, cast_u32_bf16_strided, uint32_t, bfloat)
CAST(cast_f32_bf16, cast_f32_bf16_strided, float, bfloat)

CAST_THROUGH(cast_bf16_u8, cast_bf16_u8_strided, bfloat, uint8_t, float)
CAST_THROUGH(cast_bf16_f16, cast_bf16_f16_strided, bfloat, half, float)
CAST_THROUGH(cast_f16_bf16, cast_f16_bf16_strided, half, bfloat, float)
#endif
3 changes: 3 additions & 0 deletions candle-metal-kernels/src/indexing.metal
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,9 @@ SCATTER_ADD_OP(sa_u32_f16, uint, half)


#if defined(__HAVE_BFLOAT__)
INDEX_OP(is_u32_bf16, uint32_t, bfloat)
INDEX_OP(is_u8_bf16, uint8_t, bfloat)

INDEX_ADD_OP(ia_i64_bf16, int64_t, bfloat)
INDEX_ADD_OP(ia_u32_bf16, uint32_t, bfloat)
INDEX_ADD_OP(ia_u8_bf16, uint8_t, bfloat)
Expand Down
Loading

0 comments on commit 402349d

Please sign in to comment.