From 1a32107fab4dd47870fc21ac740a8b67fdd31737 Mon Sep 17 00:00:00 2001 From: Laurent Mazare Date: Sat, 25 Jan 2025 23:31:03 +0100 Subject: [PATCH] Add a few metal gather ops. (#2740) * Add a few metal gather ops. * Fix some compilation issues. * Adjust the tolerance. --- candle-core/src/metal_backend/mod.rs | 6 ++++++ candle-metal-kernels/src/indexing.metal | 6 ++++++ candle-metal-kernels/src/lib.rs | 4 ++-- candle-metal-kernels/src/scaled_dot_product_attention.metal | 4 ++-- candle-nn/tests/sdpa.rs | 2 +- 5 files changed, 17 insertions(+), 5 deletions(-) diff --git a/candle-core/src/metal_backend/mod.rs b/candle-core/src/metal_backend/mod.rs index bffba50db8..435b2ec549 100644 --- a/candle-core/src/metal_backend/mod.rs +++ b/candle-core/src/metal_backend/mod.rs @@ -1245,6 +1245,12 @@ impl BackendStorage for MetalStorage { (DType::U32, DType::F16) => "gather_u32_f16", (DType::U32, DType::BF16) => "gather_u32_bf16", (DType::U32, DType::U32) => "gather_u32_u32", + (DType::U32, DType::I64) => "gather_u32_i64", + (DType::I64, DType::F32) => "gather_i64_f32", + (DType::I64, DType::F16) => "gather_i64_f16", + (DType::I64, DType::BF16) => "gather_i64_bf16", + (DType::I64, DType::U32) => "gather_i64_u32", + (DType::I64, DType::I64) => "gather_i64_i64", (left, right) => crate::bail!("Metal gather {left:?} {right:?} not implemented"), }; let command_buffer = self.device.command_buffer()?; diff --git a/candle-metal-kernels/src/indexing.metal b/candle-metal-kernels/src/indexing.metal index 7509b62803..df374d20d6 100644 --- a/candle-metal-kernels/src/indexing.metal +++ b/candle-metal-kernels/src/indexing.metal @@ -209,12 +209,18 @@ INDEX_OP(is_u8_f16, uint8_t, half) INDEX_OP(is_u8_bf16, uint8_t, bfloat) #endif +GATHER_OP(gather_i64_f32, int64_t, float) +GATHER_OP(gather_i64_f16, int64_t, half) GATHER_OP(gather_u32_f32, uint, float) GATHER_OP(gather_u32_f16, uint, half) #if defined(__HAVE_BFLOAT__) +GATHER_OP(gather_i64_bf16, int64_t, bfloat) GATHER_OP(gather_u32_bf16, uint, bfloat) #endif +GATHER_OP(gather_i64_u32, int64_t, uint) GATHER_OP(gather_u32_u32, uint, uint) +GATHER_OP(gather_i64_i64, int64_t, int64_t) +GATHER_OP(gather_u32_i64, uint, int64_t) SCATTER_ADD_OP(sa_u32_f32, uint32_t, float) SCATTER_ADD_OP(sa_u8_f32, uint8_t, float) diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs index 818e4a0264..79cfb99035 100644 --- a/candle-metal-kernels/src/lib.rs +++ b/candle-metal-kernels/src/lib.rs @@ -2029,7 +2029,7 @@ pub fn call_sdpa_vector_2pass( )])); let pipeline = - kernels.load_pipeline_with_constants(device, Source::Sdpa, &name_pass1, constants)?; + kernels.load_pipeline_with_constants(device, Source::Sdpa, name_pass1, constants)?; let encoder = ep.encoder(); let encoder: &ComputeCommandEncoderRef = encoder.as_ref(); encoder.set_compute_pipeline_state(&pipeline); @@ -2104,7 +2104,7 @@ pub fn call_sdpa_vector_2pass( let b = (q_shape[0] * q_shape[1]) as i32; - let pipeline = kernels.load_pipeline(device, Source::Sdpa, &name_pass2)?; + let pipeline = kernels.load_pipeline(device, Source::Sdpa, name_pass2)?; let encoder = ep.encoder(); let encoder: &ComputeCommandEncoderRef = encoder.as_ref(); encoder.set_compute_pipeline_state(&pipeline); diff --git a/candle-metal-kernels/src/scaled_dot_product_attention.metal b/candle-metal-kernels/src/scaled_dot_product_attention.metal index 0453e0d11a..ab129d13a1 100644 --- a/candle-metal-kernels/src/scaled_dot_product_attention.metal +++ b/candle-metal-kernels/src/scaled_dot_product_attention.metal @@ -1404,7 +1404,7 @@ instantiate_fast_inference_self_attention_kernel(half, half, 16, 16, 256, 2, 2); const constant size_t& v_stride, \ const constant float& scale, \ const constant float& softcapping, \ - const device bool* mask [[function_constant(sdpa_vector_has_mask)]],, \ + const device bool* mask [[function_constant(sdpa_vector_has_mask)]], \ const constant int& mask_seq_stride [[function_constant(sdpa_vector_has_mask)]], \ const constant int& mask_head_stride [[function_constant(sdpa_vector_has_mask)]], \ uint3 tid [[threadgroup_position_in_grid]], \ @@ -1424,7 +1424,7 @@ instantiate_fast_inference_self_attention_kernel(half, half, 16, 16, 256, 2, 2); const constant size_t& v_stride, \ const constant float& scale, \ const constant float& softcapping, \ - const device bool* mask [[function_constant(sdpa_vector_has_mask)]],, \ + const device bool* mask [[function_constant(sdpa_vector_has_mask)]], \ const constant int& mask_seq_stride [[function_constant(sdpa_vector_has_mask)]], \ const constant int& mask_head_stride [[function_constant(sdpa_vector_has_mask)]], \ uint3 tid [[threadgroup_position_in_grid]], \ diff --git a/candle-nn/tests/sdpa.rs b/candle-nn/tests/sdpa.rs index 67ad3816b4..664d68dcef 100644 --- a/candle-nn/tests/sdpa.rs +++ b/candle-nn/tests/sdpa.rs @@ -116,7 +116,7 @@ mod metal_sdpa_tests { .sum_all()? .to_scalar()?; - assert!(error <= 0.0004, "{}", error); + assert!(error <= 0.0005, "{}", error); Ok(()) }