Skip to content

Commit

Permalink
Use __HAVE_BFLOAT__ to check for bfloat support instead of metal vers…
Browse files Browse the repository at this point in the history
…ion check (#1540)
  • Loading branch information
ivarflakstad authored Jan 10, 2024
1 parent ae06cb7 commit d3bdd78
Show file tree
Hide file tree
Showing 6 changed files with 6 additions and 6 deletions.
2 changes: 1 addition & 1 deletion candle-metal-kernels/src/affine.metal
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ ELU(elu_f32, float)
ELU(elu_f16, half)


#if __METAL_VERSION__ >= 310
#if defined(__HAVE_BFLOAT__)
AFFINE(affine_bf16, bfloat);
POWF(powf_bf16, bfloat);
ELU(elu_bf16, bfloat);
Expand Down
2 changes: 1 addition & 1 deletion candle-metal-kernels/src/binary.metal
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ INT64_BINARY_OP_OUT(ge, x >= y)
INT64_BINARY_OP_OUT(gt, x > y)
#endif

#if __METAL_VERSION__ >= 310
#if defined(__HAVE_BFLOAT__)
BFLOAT_BINARY_OP(x + y, add)
BFLOAT_BINARY_OP(x - y, sub)
BFLOAT_BINARY_OP(x * y, mul)
Expand Down
2 changes: 1 addition & 1 deletion candle-metal-kernels/src/cast.metal
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ CAST(cast_u32_i64, cast_u32_i64_strided, uint32_t, int64_t)
CAST(cast_i64_f32, cast_i64_f32_strided, int64_t, float)
#endif

#if __METAL_VERSION__ >= 310
#if defined(__HAVE_BFLOAT__)
CAST(cast_bf16_f32, cast_bf16_f32_strided, bfloat, float)
CAST(cast_f32_bf16, cast_f32_bf16_strided, float, bfloat)
#endif
2 changes: 1 addition & 1 deletion candle-metal-kernels/src/indexing.metal
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ SCATTER_ADD_OP(sa_u32_f32, uint, float)
SCATTER_ADD_OP(sa_u32_f16, uint, half)


#if __METAL_VERSION__ >= 310
#if defined(__HAVE_BFLOAT__)
INDEX_ADD_OP(ia_i64_bf16, int64_t, bfloat)
INDEX_ADD_OP(ia_u32_bf16, uint32_t, bfloat)
INDEX_ADD_OP(ia_u8_bf16, uint8_t, bfloat)
Expand Down
2 changes: 1 addition & 1 deletion candle-metal-kernels/src/reduce.metal
Original file line number Diff line number Diff line change
Expand Up @@ -295,7 +295,7 @@ ARGMIN(fast_argmin_i64_strided, int64_t, INT_MAX)
ARGMAX(fast_argmax_i64_strided, int64_t, INT_MIN)
#endif

#if __METAL_VERSION__ >= 310
#if defined(__HAVE_BFLOAT__)
REDUCE(x + y, fast_sum_bf16, bfloat, 0)
REDUCE(x * y, fast_mul_bf16, bfloat, 1)
REDUCE(MAX(x, y), fast_max_bf16, bfloat, -HUGE_VALBF)
Expand Down
2 changes: 1 addition & 1 deletion candle-metal-kernels/src/unary.metal
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ UNARY(id, uint32_t, copy_u32, copy_u32_strided)
UNARY(id, int64_t, copy_i64, copy_i64_strided)
#endif

#if __METAL_VERSION__ >= 310
#if defined(__HAVE_BFLOAT__)
BFLOAT_UNARY_OP(cos)
BFLOAT_UNARY_OP(sin)
BFLOAT_UNARY_OP(sqr)
Expand Down

0 comments on commit d3bdd78

Please sign in to comment.