diff --git a/ggml/src/ggml-metal/ggml-metal-impl.h b/ggml/src/ggml-metal/ggml-metal-impl.h index 8721b272de6a9..ea8cf45863af9 100644 --- a/ggml/src/ggml-metal/ggml-metal-impl.h +++ b/ggml/src/ggml-metal/ggml-metal-impl.h @@ -299,21 +299,42 @@ typedef struct { } ggml_metal_kargs_mul_mv_ext; typedef struct { - int32_t nei0; - int32_t nei1; - uint64_t nbi1; + int32_t ne10; + int32_t ne11; // n_expert_used (bcast) + uint64_t nb11; + uint64_t nb12; + int32_t neh11; // n_tokens + uint64_t nbh11; + int32_t ne20; // n_expert_used + uint64_t nb21; +} ggml_metal_kargs_mul_mm_id_map0; + +typedef struct { + int32_t ne20; // n_expert_used + int32_t neh0; + int32_t neh1; + uint64_t nbh1; + uint64_t nbh2; + int32_t ne0; + uint64_t nb1; + uint64_t nb2; +} ggml_metal_kargs_mul_mm_id_map1; + +typedef struct { int32_t ne00; int32_t ne02; uint64_t nb01; uint64_t nb02; - int32_t ne11; - int32_t ne12; - int32_t ne13; - uint64_t nb10; - uint64_t nb11; - uint64_t nb12; - int32_t ne0; - int32_t ne1; + uint64_t nb03; + int32_t neh12; + uint64_t nbh10; + uint64_t nbh11; + uint64_t nbh12; + uint64_t nbh13; + int32_t neh0; + int32_t neh1; + int16_t r2; + int16_t r3; } ggml_metal_kargs_mul_mm_id; typedef struct { diff --git a/ggml/src/ggml-metal/ggml-metal.m b/ggml/src/ggml-metal/ggml-metal.m index d92392edb7eb1..943f0fe722a2f 100644 --- a/ggml/src/ggml-metal/ggml-metal.m +++ b/ggml/src/ggml-metal/ggml-metal.m @@ -306,28 +306,30 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F32, GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32, GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32, - GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32, - GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32, - GGML_METAL_KERNEL_TYPE_MUL_MM_ID_BF16_F32, - GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32, - GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F32, - GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F32, - GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F32, - GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F32, - GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F32, - GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F32, - GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F32, - GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_K_F32, - GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F32, - GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32, - GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32, - GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32, - GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_S_F32, - GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F32, - GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32, - GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F32, - GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32, - GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32, + GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16, + GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP1_F32, + GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F16, + GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F16, + GGML_METAL_KERNEL_TYPE_MUL_MM_ID_BF16_F16, + GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F16, + GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F16, + GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F16, + GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F16, + GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F16, + GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F16, + GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F16, + GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F16, + GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_K_F16, + GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F16, + GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F16, + GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F16, + GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F16, + GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_S_F16, + GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F16, + GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F16, + GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F16, + GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F16, + GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F16, GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32, GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16, GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32, @@ -650,7 +652,8 @@ static void ggml_metal_mem_pool_reset(struct ggml_metal_mem_pool * mem_pool) { } if (mem_pool->heaps_to_remove.count > 0) { - for (NSUInteger i = 0; i < [mem_pool->heaps_to_remove count]; i++) { + // remove in reverse order + for (NSUInteger i = [mem_pool->heaps_to_remove count] - 1; ; --i) { NSUInteger index = [[mem_pool->heaps_to_remove objectAtIndex:i] intValue]; ggml_metal_heap_ptr * ptr = [mem_pool->heaps objectAtIndex:index]; @@ -659,6 +662,10 @@ static void ggml_metal_mem_pool_reset(struct ggml_metal_mem_pool * mem_pool) { [mem_pool->heaps removeObjectAtIndex:index]; [ptr release]; + + if (i == 0) { + break; + } } [mem_pool->heaps_to_remove removeAllObjects]; @@ -672,7 +679,7 @@ static void ggml_metal_mem_pool_clear(struct ggml_metal_mem_pool * mem_pool) { } static id<MTLBuffer> ggml_metal_mem_pool_alloc(struct ggml_metal_mem_pool * mem_pool, size_t size) { - const size_t alignment = 32; + const size_t alignment = 256; const size_t size_aligned = GGML_PAD(size, alignment); @@ -1242,28 +1249,30 @@ @implementation GGMLMetalClass GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F32, mul_mm_iq1_m_f32, has_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32, mul_mm_iq4_nl_f32, has_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32, mul_mm_iq4_xs_f32, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32, mul_mm_id_f32_f32, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32, mul_mm_id_f16_f32, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_BF16_F32, mul_mm_id_bf16_f32, has_simdgroup_mm && use_bfloat); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32, mul_mm_id_q4_0_f32, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F32, mul_mm_id_q4_1_f32, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F32, mul_mm_id_q5_0_f32, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F32, mul_mm_id_q5_1_f32, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F32, mul_mm_id_q8_0_f32, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F32, mul_mm_id_q2_K_f32, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F32, mul_mm_id_q3_K_f32, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F32, mul_mm_id_q4_K_f32, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_K_F32, mul_mm_id_q5_K_f32, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F32, mul_mm_id_q6_K_f32, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32, mul_mm_id_iq2_xxs_f32, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32, mul_mm_id_iq2_xs_f32, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32, mul_mm_id_iq3_xxs_f32, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_S_F32, mul_mm_id_iq3_s_f32, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F32, mul_mm_id_iq2_s_f32, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32, mul_mm_id_iq1_s_f32, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F32, mul_mm_id_iq1_m_f32, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32, mul_mm_id_iq4_nl_f32, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32, mul_mm_id_iq4_xs_f32, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16, mul_mm_id_map0_f16, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP1_F32, mul_mm_id_map1_f32, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F16, mul_mm_id_f32_f16, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F16, mul_mm_id_f16_f16, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_BF16_F16, mul_mm_id_bf16_f16, has_simdgroup_mm && use_bfloat); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F16, mul_mm_id_q4_0_f16, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F16, mul_mm_id_q4_1_f16, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F16, mul_mm_id_q5_0_f16, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F16, mul_mm_id_q5_1_f16, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F16, mul_mm_id_q8_0_f16, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F16, mul_mm_id_q2_K_f16, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F16, mul_mm_id_q3_K_f16, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F16, mul_mm_id_q4_K_f16, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_K_F16, mul_mm_id_q5_K_f16, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F16, mul_mm_id_q6_K_f16, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F16, mul_mm_id_iq2_xxs_f16, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F16, mul_mm_id_iq2_xs_f16, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F16, mul_mm_id_iq3_xxs_f16, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_S_F16, mul_mm_id_iq3_s_f16, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F16, mul_mm_id_iq2_s_f16, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F16, mul_mm_id_iq1_s_f16, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F16, mul_mm_id_iq1_m_f16, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F16, mul_mm_id_iq4_nl_f16, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F16, mul_mm_id_iq4_xs_f16, has_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32, rope_norm_f32, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16, rope_norm_f16, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32, rope_neox_f32, true); @@ -2999,7 +3008,7 @@ static bool ggml_metal_encode_node( [encoder setBuffer:id_dst offset:offs_dst atIndex:3]; [encoder setThreadgroupMemoryLength:8192 atIndex:0]; - [encoder dispatchThreadgroups:MTLSizeMake( (ne11 + 31)/32, (ne01 + 63)/64, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)]; + [encoder dispatchThreadgroups:MTLSizeMake((ne11 + 31)/32, (ne01 + 63)/64, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)]; } else { id<MTLComputePipelineState> pipeline = nil; @@ -3219,8 +3228,6 @@ static bool ggml_metal_encode_node( } break; case GGML_OP_MUL_MAT_ID: { - const int n_as = src0->ne[2]; - // src2 = ids const enum ggml_type src2t = src2->type; GGML_UNUSED(src2t); @@ -3234,24 +3241,21 @@ static bool ggml_metal_encode_node( GGML_ASSERT(ne03 == 1); GGML_ASSERT(ne13 == 1); + const uint32_t r2 = 1; + const uint32_t r3 = 1; + // find the break-even point where the matrix-matrix kernel becomes more efficient compared // to the matrix-vector kernel // ne20 = n_used_experts - // ne21 = n_rows - const int dst_rows = ne20*ne21; - const int dst_rows_min = n_as; - const int dst_rows_max = (device.maxThreadgroupMemoryLength/2 - 8192)/4; - - // max size of the rowids array in the kernel shared buffer - //GGML_ASSERT(dst_rows <= dst_rows_max); + // ne21 = n_rows (batch size) + const int ne21_mm_id_min = 32; // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs // AMD GPU and older A-chips will reuse matrix-vector multiplication kernel if ([device supportsFamily:MTLGPUFamilyApple7] && ne00 % 32 == 0 && ne00 >= 64 && - //ne01 / ne02 >= 512 && // NOTE: this is based on Mixtral shapes, might need adjustments - dst_rows > dst_rows_min && - dst_rows <= dst_rows_max) { + (ne21 >= ne21_mm_id_min)) { + GGML_ASSERT(ne00 % 4 == 0); // some Metal matrix data types require aligned pointers // ref: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf (Table 2.5) @@ -3262,62 +3266,169 @@ static bool ggml_metal_encode_node( default: break; } - id<MTLComputePipelineState> pipeline = nil; + const int64_t neh10 = ne10; // n_embd + const int64_t neh11 = ne21; // n_tokens + const int64_t neh12 = ne02; // n_expert - switch (src0->type) { - case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32 ].pipeline; break; - case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32 ].pipeline; break; - case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_BF16_F32 ].pipeline; break; - case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32 ].pipeline; break; - case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F32 ].pipeline; break; - case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F32 ].pipeline; break; - case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F32 ].pipeline; break; - case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F32 ].pipeline; break; - case GGML_TYPE_Q2_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F32 ].pipeline; break; - case GGML_TYPE_Q3_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F32 ].pipeline; break; - case GGML_TYPE_Q4_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F32 ].pipeline; break; - case GGML_TYPE_Q5_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_K_F32 ].pipeline; break; - case GGML_TYPE_Q6_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F32 ].pipeline; break; - case GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32].pipeline; break; - case GGML_TYPE_IQ2_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32 ].pipeline; break; - case GGML_TYPE_IQ3_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32].pipeline; break; - case GGML_TYPE_IQ3_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_S_F32 ].pipeline; break; - case GGML_TYPE_IQ2_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F32 ].pipeline; break; - case GGML_TYPE_IQ1_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32 ].pipeline; break; - case GGML_TYPE_IQ1_M: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F32 ].pipeline; break; - case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32 ].pipeline; break; - case GGML_TYPE_IQ4_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32 ].pipeline; break; - default: GGML_ABORT("MUL_MAT_ID not implemented"); + const uint64_t nbh10 = ggml_type_size(GGML_TYPE_F16); + const uint64_t nbh11 = nbh10*neh10; + const uint64_t nbh12 = nbh11*neh11; + const uint64_t nbh13 = nbh12*neh12; + + const size_t s_src1 = ggml_type_size(GGML_TYPE_F16)*neh10*neh11*neh12; + id<MTLBuffer> h_src1 = ggml_metal_mem_pool_alloc(mem_pool, s_src1); + if (!h_src1) { + GGML_LOG_ERROR("%s: failed to allocate buffer from memory pool, size = %zu\n", __func__, s_src1); + return false; } - ggml_metal_kargs_mul_mm_id args = { - /*.nei0 =*/ ne20, - /*.nei1 =*/ ne21, - /*.nbi1 =*/ nb21, - /*.ne00 =*/ ne00, - /*.ne02 =*/ ne02, - /*.nb01 =*/ nb01, - /*.nb02 =*/ nb02, - /*.ne11 =*/ ne11, - /*.ne12 =*/ ne12, - /*.ne13 =*/ ne13, - /*.nb10 =*/ nb10, - /*.nb11 =*/ nb11, - /*.nb12 =*/ nb12, - /*.ne0 =*/ ne0, - /*.ne1 =*/ ne1, - }; + const int64_t neh0 = ne0; + const int64_t neh1 = ne21; + const int64_t neh2 = ne02; - [encoder setComputePipelineState:pipeline]; - [encoder setBytes:&args length:sizeof(args) atIndex:0]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1]; - [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:3]; - [encoder setBuffer:id_src2 offset:offs_src2 atIndex:4]; + const uint64_t nbh0 = ggml_type_size(GGML_TYPE_F32); + const uint64_t nbh1 = nbh0*neh0; + const uint64_t nbh2 = nbh1*neh1; + //const uint64_t nbh3 = nbh2*neh2; + + const size_t s_dst = ggml_type_size(GGML_TYPE_F32)*neh0*neh1*neh2; + id<MTLBuffer> h_dst = ggml_metal_mem_pool_alloc(mem_pool, s_dst); + if (!h_dst) { + GGML_LOG_ERROR("%s: failed to allocate buffer from memory pool, size = %zu\n", __func__, s_dst); + return false; + } + + // tokens per expert + const size_t s_tpe = ggml_type_size(GGML_TYPE_I32)*ne02; + id<MTLBuffer> h_tpe = ggml_metal_mem_pool_alloc(mem_pool, s_tpe); + if (!h_tpe) { + GGML_LOG_ERROR("%s: failed to allocate buffer from memory pool, size = %zu\n", __func__, s_tpe); + return false; + } + + // id map + // [n_expert_used, n_tokens] + const size_t s_ids = ggml_type_size(GGML_TYPE_I32)*ne20*ne21; + id<MTLBuffer> h_ids = ggml_metal_mem_pool_alloc(mem_pool, s_ids); + if (!h_ids) { + GGML_LOG_ERROR("%s: failed to allocate buffer from memory pool, size = %zu\n", __func__, s_ids); + return false; + } + + { + const int nth = MIN(1024, ne10/4); + + ggml_metal_kargs_mul_mm_id_map0 args = { + ne10, + ne11, // n_expert_used (bcast) + nb11, + nb12, + neh11, // n_tokens + nbh11, + ne20, // n_expert_used + nb21, + }; + + id<MTLComputePipelineState> pipeline = nil; + + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16].pipeline; + + [encoder setComputePipelineState:pipeline]; + [encoder setBytes:&args length:sizeof(args) atIndex:0]; + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; + [encoder setBuffer:id_src2 offset:offs_src2 atIndex:2]; + [encoder setBuffer: h_src1 offset:0 atIndex:3]; + [encoder setBuffer: h_tpe offset:0 atIndex:4]; + [encoder setBuffer: h_ids offset:0 atIndex:5]; + + [encoder dispatchThreadgroups:MTLSizeMake(ne02, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; + } + + { + id<MTLComputePipelineState> pipeline = nil; + + switch (src0->type) { + case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F16 ].pipeline; break; + case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F16 ].pipeline; break; + case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_BF16_F16 ].pipeline; break; + case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F16 ].pipeline; break; + case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F16 ].pipeline; break; + case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F16 ].pipeline; break; + case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F16 ].pipeline; break; + case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F16 ].pipeline; break; + case GGML_TYPE_Q2_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F16 ].pipeline; break; + case GGML_TYPE_Q3_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F16 ].pipeline; break; + case GGML_TYPE_Q4_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F16 ].pipeline; break; + case GGML_TYPE_Q5_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_K_F16 ].pipeline; break; + case GGML_TYPE_Q6_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F16 ].pipeline; break; + case GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F16].pipeline; break; + case GGML_TYPE_IQ2_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F16 ].pipeline; break; + case GGML_TYPE_IQ3_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F16].pipeline; break; + case GGML_TYPE_IQ3_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_S_F16 ].pipeline; break; + case GGML_TYPE_IQ2_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F16 ].pipeline; break; + case GGML_TYPE_IQ1_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F16 ].pipeline; break; + case GGML_TYPE_IQ1_M: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F16 ].pipeline; break; + case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F16 ].pipeline; break; + case GGML_TYPE_IQ4_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F16 ].pipeline; break; + default: GGML_ABORT("MUL_MAT_ID not implemented"); + } + + ggml_metal_kargs_mul_mm_id args = { + /*.ne00 =*/ ne00, + /*.ne02 =*/ ne02, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.nb03 =*/ nb03, + /*.neh12 =*/ neh12, + /*.nbh10 =*/ nbh10, + /*.nbh11 =*/ nbh11, + /*.nbh12 =*/ nbh12, + /*.nbh13 =*/ nbh13, + /*.neh0 =*/ neh0, + /*.neh1 =*/ neh1, + /*.r2 =*/ r2, + /*.r3 =*/ r3, + }; + + [encoder setComputePipelineState:pipeline]; + [encoder setBytes:&args length:sizeof(args) atIndex:0]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1]; + [encoder setBuffer: h_src1 offset:0 atIndex:2]; + [encoder setBuffer: h_tpe offset:0 atIndex:3]; + [encoder setBuffer: h_dst offset:0 atIndex:4]; + + [encoder setThreadgroupMemoryLength:8192 atIndex:0]; + [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 31)/32, (ne01 + 63)/64, ne02) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)]; + } - [encoder setThreadgroupMemoryLength:GGML_PAD(8192 + dst_rows*4/*sizeof(ushort2)*/, 16) atIndex:0]; + { + GGML_ASSERT(ne0 % 4 == 0); - [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 31)/32, (ne01 + 63)/64, n_as) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)]; + const int nth = MIN(1024, ne0/4); + + ggml_metal_kargs_mul_mm_id_map1 args = { + ne20, // n_expert_used + neh0, + neh1, + nbh1, + nbh2, + ne0, + nb1, + nb2, + }; + + id<MTLComputePipelineState> pipeline = nil; + + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP1_F32].pipeline; + + [encoder setComputePipelineState:pipeline]; + [encoder setBytes:&args length:sizeof(args) atIndex:0]; + [encoder setBuffer: h_dst offset:0 atIndex:1]; + [encoder setBuffer: h_ids offset:0 atIndex:2]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:3]; + + [encoder dispatchThreadgroups:MTLSizeMake(ne20, ne21, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; + } } else { id<MTLComputePipelineState> pipeline = nil; @@ -3511,7 +3622,7 @@ static bool ggml_metal_encode_node( [encoder setBuffer:id_src2 offset:offs_src2 atIndex:4]; const int64_t _ne1 = 1; - const int64_t ne123 = dst_rows; + const int64_t ne123 = ne20*ne21; if (smem > 0) { [encoder setThreadgroupMemoryLength:smem atIndex:0]; diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index 9f4147e93974d..a808e79c22752 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -6336,127 +6336,219 @@ kernel void kernel_mul_mm( } } -// same as kernel_mul_mm_impl, but src1 and dst are accessed via indices stored in rowids -// TODO: this kernel needs to be reimplemented from scratch for better performance -template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)> -void kernel_mul_mm_id_impl( - int32_t ne00, - int32_t ne02, - uint64_t nb01, - uint64_t nb02, - int32_t ne11, - int32_t ne12, - uint64_t nb10, - uint64_t nb11, - uint64_t nb12, - int32_t ne0, - int32_t ne1, - int64_t ne0ne1, - device const char * src0, - device const char * src1, - threadgroup ushort2 * rowids, - device char * dst, - threadgroup char * shmem, +template<typename T4> +kernel void kernel_mul_mm_id_map0( + constant ggml_metal_kargs_mul_mm_id_map0 & args, + device const char * src1, + device const char * src2, + device char * hsrc1, + device char * htpe, + device char * hids, + uint3 tgpig[[threadgroup_position_in_grid]], + ushort3 tpitg[[thread_position_in_threadgroup]], + ushort3 ntg[[threads_per_threadgroup]]) { + const int ide = tgpig[0]; // expert id + + int n_all = 0; + + device int32_t * ids_i32 = (device int32_t *) (hids); + + for (int i21 = 0; i21 < args.neh11; i21++) { // n_tokens + device const int32_t * src2_i32 = (device const int32_t *) (src2 + i21*args.nb21); + + for (int i20 = 0; i20 < args.ne20; i20++) { // n_expert_used + if (src2_i32[i20] != ide) { + continue; + } + + device const float4 * src1_f32x4 = (device const float4 *) ( src1 + i21*args.nb12 + (i20%args.ne11)*args.nb11); + device T4 * hsrc1_f32x4 = (device T4 *) (hsrc1 + (ide*args.neh11 + n_all)*args.nbh11); + + for (int64_t i00 = tpitg.x; i00 < args.ne10/4; i00 += ntg.x) { + hsrc1_f32x4[i00] = (T4) (src1_f32x4[i00]); + } + + if (tpitg.x == 0) { + ids_i32[i21*args.ne20 + i20] = ide*args.neh11 + n_all; + } + + ++n_all; + } + } + + if (tpitg.x == 0) { + device int32_t * tpe_i32 = (device int32_t *) (htpe); + tpe_i32[ide] = n_all; + } +} + +typedef decltype(kernel_mul_mm_id_map0<half4>) kernel_mul_mm_id_map0_t; + +template [[host_name("kernel_mul_mm_id_map0_f16")]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<half4>; + +template<typename T> +kernel void kernel_mul_mm_id_map1( + constant ggml_metal_kargs_mul_mm_id_map1 & args, + device const char * hdst, + device const char * hids, + device char * dst, + uint3 tgpig[[threadgroup_position_in_grid]], + ushort3 tpitg[[thread_position_in_threadgroup]], + ushort3 ntg[[threads_per_threadgroup]]) { + const int i20 = tgpig[0]; // used expert + const int i21 = tgpig[1]; // token + + device const int32_t * ids_i32 = (device const int32_t *) (hids); + device float4 * dst_f32x4 = (device float4 *) (dst + i20*args.nb1 + i21*args.nb2); + + const int id = ids_i32[i21*args.ne20 + i20]; + + const int ide = id / args.neh1; + const int idt = id % args.neh1; + + device const float4 * hdst_f32x4 = (device const float4 *) (hdst + idt*args.nbh1 + ide*args.nbh2); + + for (int64_t i0 = tpitg.x; i0 < args.neh0/4; i0 += ntg.x) { + dst_f32x4[i0] = hdst_f32x4[i0]; + } +} + +typedef decltype(kernel_mul_mm_id_map1<float>) kernel_mul_mm_id_map1_t; + +template [[host_name("kernel_mul_mm_id_map1_f32")]] kernel kernel_mul_mm_id_map1_t kernel_mul_mm_id_map1<float>; + +template<typename T, typename T4x4, typename simdgroup_T8x8, typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread T4x4 &)> +kernel void kernel_mul_mm_id( + constant ggml_metal_kargs_mul_mm_id & args, + device const char * src0, + device const char * src1, + device const char * tpe, + device char * dst, + threadgroup char * shmem [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], ushort tiitg[[thread_index_in_threadgroup]], ushort sgitg[[simdgroup_index_in_threadgroup]]) { - threadgroup half * sa = (threadgroup half *)(shmem); - threadgroup float * sb = (threadgroup float *)(shmem + 4096); + threadgroup T * sa = (threadgroup T *)(shmem); + threadgroup half * sb = (threadgroup half *)(shmem + 4096); const int r0 = tgpig.y; const int r1 = tgpig.x; + const int im = tgpig.z; + + device const int32_t * tpe_i32 = (device const int32_t *) (tpe); + + const int neh1 = tpe_i32[im]; - if (r1*BLOCK_SIZE_N >= ne1) return; + if (r1*BLOCK_SIZE_N >= neh1) { + return; + } // if this block is of 64x32 shape or smaller - short n_rows = (ne0 - r0 * BLOCK_SIZE_M < BLOCK_SIZE_M) ? (ne0 - r0 * BLOCK_SIZE_M) : BLOCK_SIZE_M; - short n_cols = (ne1 - r1 * BLOCK_SIZE_N < BLOCK_SIZE_N) ? (ne1 - r1 * BLOCK_SIZE_N) : BLOCK_SIZE_N; + const short n_rows = (args.neh0 - r0*BLOCK_SIZE_M < BLOCK_SIZE_M) ? (args.neh0 - r0*BLOCK_SIZE_M) : BLOCK_SIZE_M; + const short n_cols = ( neh1 - r1*BLOCK_SIZE_N < BLOCK_SIZE_N) ? ( neh1 - r1*BLOCK_SIZE_N) : BLOCK_SIZE_N; // a thread shouldn't load data outside of the matrix - short thread_row = ((short)tiitg/THREAD_PER_ROW) < n_rows ? ((short)tiitg/THREAD_PER_ROW) : n_rows - 1; - short thread_col = ((short)tiitg/THREAD_PER_COL) < n_cols ? ((short)tiitg/THREAD_PER_COL) : n_cols - 1; + const short thread_row = ((short)tiitg/THREAD_PER_ROW) < n_rows ? ((short)tiitg/THREAD_PER_ROW) : n_rows - 1; + const short thread_col = ((short)tiitg/THREAD_PER_COL) < n_cols ? ((short)tiitg/THREAD_PER_COL) : n_cols - 1; - simdgroup_half8x8 ma[4]; - simdgroup_float8x8 mb[2]; + simdgroup_T8x8 ma[4]; + simdgroup_half8x8 mb[2]; simdgroup_float8x8 mc[8]; - for (int i = 0; i < 8; i++){ + + for (short i = 0; i < 8; i++){ mc[i] = make_filled_simdgroup_matrix<float, 8>(0.f); } + short il = (tiitg % THREAD_PER_ROW); - ushort offset1 = il/nl; + const int i12 = im%args.neh12; + const int i13 = im/args.neh12; - threadgroup const auto & id = rowids[r1 * BLOCK_SIZE_N + thread_col]; + const uint64_t offset0 = (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const short offset1 = il/nl; - device const block_q * x = (device const block_q *)(src0 + (r0 * BLOCK_SIZE_M + thread_row) * nb01) + offset1; - device const float * y = (device const float *)(src1 - + nb12 * id[1] - + nb11 * (id[0] % ne11) - + nb10 * (BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL))); + device const block_q * x = (device const block_q *)(src0 + + args.nb01*(r0*BLOCK_SIZE_M + thread_row) + offset0) + offset1; - for (int loop_k = 0; loop_k < ne00; loop_k += BLOCK_SIZE_K) { + device const half * y = (device const half *)(src1 + + args.nbh13*i13 + + args.nbh12*i12 + + args.nbh11*(r1*BLOCK_SIZE_N + thread_col) + + args.nbh10*(BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL))); + + for (int loop_k = 0; loop_k < args.ne00; loop_k += BLOCK_SIZE_K) { // load data and store to threadgroup memory - half4x4 temp_a; + T4x4 temp_a; dequantize_func(x, il, temp_a); + threadgroup_barrier(mem_flags::mem_threadgroup); - for (int i = 0; i < 16; i++) { - *(sa + SG_MAT_SIZE * ((tiitg / THREAD_PER_ROW / 8) \ - + (tiitg % THREAD_PER_ROW) * 16 + (i / 8) * 8) \ - + (tiitg / THREAD_PER_ROW) % 8 + (i & 7) * 8) = temp_a[i/4][i%4]; + #pragma unroll(16) + for (short i = 0; i < 16; i++) { + *(sa + SG_MAT_SIZE * ((tiitg/THREAD_PER_ROW/8) \ + + (tiitg%THREAD_PER_ROW)*16 + (i/8)*8) \ + + (tiitg/THREAD_PER_ROW)%8 + (i&7)*8) = temp_a[i/4][i%4]; } - *(threadgroup float2x4 *)(sb + (tiitg % THREAD_PER_COL) * 8 * 32 + 8 * (tiitg / THREAD_PER_COL)) = *((device float2x4 *)y); + *(threadgroup half2x4 *)(sb + 32*8*(tiitg%THREAD_PER_COL) + 8*(tiitg/THREAD_PER_COL)) = *((device half2x4 *) y); il = (il + 2 < nl) ? il + 2 : il % 2; - x = (il < 2) ? x + (2+nl-1)/nl : x; + x = (il < 2) ? x + (2 + nl - 1)/nl : x; y += BLOCK_SIZE_K; threadgroup_barrier(mem_flags::mem_threadgroup); // load matrices from threadgroup memory and conduct outer products - threadgroup half * lsma = (sa + THREAD_MAT_M * SG_MAT_SIZE * (sgitg % 2)); - threadgroup float * lsmb = (sb + THREAD_MAT_N * SG_MAT_SIZE * (sgitg / 2)); + threadgroup const T * lsma = (sa + THREAD_MAT_M*SG_MAT_SIZE*(sgitg%2)); + threadgroup const half * lsmb = (sb + THREAD_MAT_N*SG_MAT_SIZE*(sgitg/2)); - #pragma unroll(BLOCK_SIZE_K/8) - for (int ik = 0; ik < BLOCK_SIZE_K / 8; ik++) { + #pragma unroll(4) + for (short ik = 0; ik < BLOCK_SIZE_K/8; ik++) { #pragma unroll(4) - for (int i = 0; i < 4; i++) { + for (short i = 0; i < 4; i++) { simdgroup_load(ma[i], lsma + SG_MAT_SIZE * i); } + simdgroup_barrier(mem_flags::mem_none); + #pragma unroll(2) - for (int i = 0; i < 2; i++) { + for (short i = 0; i < 2; i++) { simdgroup_load(mb[i], lsmb + SG_MAT_SIZE * i); } - lsma += BLOCK_SIZE_M / SG_MAT_ROW * SG_MAT_SIZE; - lsmb += BLOCK_SIZE_N / SG_MAT_ROW * SG_MAT_SIZE; - #pragma unroll(8) - for (int i = 0; i < 8; i++){ + for (short i = 0; i < 8; i++){ simdgroup_multiply_accumulate(mc[i], mb[i/4], ma[i%4], mc[i]); } + + lsma += (BLOCK_SIZE_M/SG_MAT_ROW)*SG_MAT_SIZE; + lsmb += (BLOCK_SIZE_N/SG_MAT_ROW)*SG_MAT_SIZE; } } - { + if ((r0 + 1) * BLOCK_SIZE_M <= args.neh0 && (r1 + 1) * BLOCK_SIZE_N <= neh1) { + device float * C = (device float *) dst + + (BLOCK_SIZE_M * r0 + 32*(sgitg & 1)) + \ + (BLOCK_SIZE_N * r1 + 16*(sgitg >> 1)) * args.neh0 + im*args.neh1*args.neh0; + + for (short i = 0; i < 8; i++) { + simdgroup_store(mc[i], C + 8 * (i%4) + 8 * args.neh0 * (i/4), args.neh0); + } + } else { + // block is smaller than 64x32, we should avoid writing data outside of the matrix threadgroup_barrier(mem_flags::mem_threadgroup); threadgroup float * temp_str = ((threadgroup float *) shmem) \ - + 32 * (sgitg&1) + (16 * (sgitg>>1)) * BLOCK_SIZE_M; - for (int i = 0; i < 8; i++) { - simdgroup_store(mc[i], temp_str + 8 * (i%4) + 8 * BLOCK_SIZE_M * (i/4), BLOCK_SIZE_M); + + 32*(sgitg&1) + (16*(sgitg >> 1))*BLOCK_SIZE_M; + for (short i = 0; i < 8; i++) { + simdgroup_store(mc[i], temp_str + 8*(i%4) + 8*BLOCK_SIZE_M*(i/4), BLOCK_SIZE_M); } threadgroup_barrier(mem_flags::mem_threadgroup); if (sgitg == 0) { for (int j = tiitg; j < n_cols; j += BLOCK_SIZE_N) { - threadgroup const auto & jid = rowids[r1 * BLOCK_SIZE_N + j]; - int64_t joff = jid[0]*ne0 + jid[1]*ne0ne1; - - device float * D = (device float *) dst + (r0*BLOCK_SIZE_M) + joff; + device float * D = (device float *) dst + (r0*BLOCK_SIZE_M) + (r1*BLOCK_SIZE_N + j)*args.neh0 + im*args.neh1*args.neh0; device float4 * D4 = (device float4 *) D; threadgroup float * C = temp_str + (j*BLOCK_SIZE_M); @@ -6476,66 +6568,6 @@ void kernel_mul_mm_id_impl( } } -template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)> -kernel void kernel_mul_mm_id( - constant ggml_metal_kargs_mul_mm_id & args, - device const char * src0s, - device const char * src1, - device char * dst, - device const char * ids, - threadgroup char * shmem [[threadgroup(0)]], - uint3 tgpig[[threadgroup_position_in_grid]], - ushort tiitg[[thread_index_in_threadgroup]], - ushort sgitg[[simdgroup_index_in_threadgroup]]) { - - const int32_t i02 = tgpig.z; - - tgpig.z = 0; - - device const char * src0 = src0s + i02*args.nb02; - - // row indices - threadgroup ushort2 * rowids = (threadgroup ushort2 *)(shmem + 8192); - - // TODO: parallelize this loop - int32_t _ne1 = 0; - for (ushort ii1 = 0; ii1 < args.nei1; ii1++) { - for (ushort ii0 = 0; ii0 < args.nei0; ii0++) { - int32_t id = ((device int32_t *) (ids + ii1*args.nbi1))[ii0]; - if (id == i02) { - if (tiitg == 0) { - rowids[_ne1] = ushort2(ii0, ii1); - } - _ne1++; - } - } - } - - threadgroup_barrier(mem_flags::mem_threadgroup); - - kernel_mul_mm_id_impl<block_q, nl, dequantize_func>( - args.ne00, - args.ne02, - args.nb01, - args.nb02, - args.ne11, - args.ne12, - args.nb10, - args.nb11, - args.nb12, - args.ne0, - _ne1, - (int64_t)args.ne0*args.ne1, - src0, - src1, - rowids, - dst, - shmem, - tgpig, - tiitg, - sgitg); -} - #define QK_NL 16 // @@ -6576,63 +6608,64 @@ template [[host_name("kernel_get_rows_iq4_xs")]] kernel get_rows_q_t kernel_get // matrix-matrix multiplication // -typedef decltype(kernel_mul_mm<half, half4x4, simdgroup_half8x8, float4x4, 1, dequantize_f32>) mat_mm_t; +typedef decltype(kernel_mul_mm<half, half4x4, simdgroup_half8x8, float4x4, 1, dequantize_f32>) mul_mm_t; -template [[host_name("kernel_mul_mm_f32_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, float4x4, 1, dequantize_f32>; -template [[host_name("kernel_mul_mm_f16_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half4x4, 1, dequantize_f16>; +template [[host_name("kernel_mul_mm_f32_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, float4x4, 1, dequantize_f32>; +template [[host_name("kernel_mul_mm_f16_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half4x4, 1, dequantize_f16>; #if defined(GGML_METAL_USE_BF16) -template [[host_name("kernel_mul_mm_bf16_f32")]] kernel mat_mm_t kernel_mul_mm<bfloat, bfloat4x4, simdgroup_bfloat8x8, bfloat4x4, 1, dequantize_bf16>; +template [[host_name("kernel_mul_mm_bf16_f32")]] kernel mul_mm_t kernel_mul_mm<bfloat, bfloat4x4, simdgroup_bfloat8x8, bfloat4x4, 1, dequantize_bf16>; #endif -template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q4_0, 2, dequantize_q4_0>; -template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q4_1, 2, dequantize_q4_1>; -template [[host_name("kernel_mul_mm_q5_0_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q5_0, 2, dequantize_q5_0>; -template [[host_name("kernel_mul_mm_q5_1_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q5_1, 2, dequantize_q5_1>; -template [[host_name("kernel_mul_mm_q8_0_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q8_0, 2, dequantize_q8_0>; -template [[host_name("kernel_mul_mm_q2_K_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q2_K, QK_NL, dequantize_q2_K>; -template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q3_K, QK_NL, dequantize_q3_K>; -template [[host_name("kernel_mul_mm_q4_K_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q4_K, QK_NL, dequantize_q4_K>; -template [[host_name("kernel_mul_mm_q5_K_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q5_K, QK_NL, dequantize_q5_K>; -template [[host_name("kernel_mul_mm_q6_K_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q6_K, QK_NL, dequantize_q6_K>; -template [[host_name("kernel_mul_mm_iq2_xxs_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq2_xxs, QK_NL, dequantize_iq2_xxs>; -template [[host_name("kernel_mul_mm_iq2_xs_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq2_xs, QK_NL, dequantize_iq2_xs>; -template [[host_name("kernel_mul_mm_iq3_xxs_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq3_xxs, QK_NL, dequantize_iq3_xxs>; -template [[host_name("kernel_mul_mm_iq3_s_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq3_s, QK_NL, dequantize_iq3_s>; -template [[host_name("kernel_mul_mm_iq2_s_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq2_s, QK_NL, dequantize_iq2_s>; -template [[host_name("kernel_mul_mm_iq1_s_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq1_s, QK_NL, dequantize_iq1_s>; -template [[host_name("kernel_mul_mm_iq1_m_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq1_m, QK_NL, dequantize_iq1_m>; -template [[host_name("kernel_mul_mm_iq4_nl_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq4_nl, 2, dequantize_iq4_nl>; -template [[host_name("kernel_mul_mm_iq4_xs_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq4_xs, QK_NL, dequantize_iq4_xs>; +template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q4_0, 2, dequantize_q4_0>; +template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q4_1, 2, dequantize_q4_1>; +template [[host_name("kernel_mul_mm_q5_0_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q5_0, 2, dequantize_q5_0>; +template [[host_name("kernel_mul_mm_q5_1_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q5_1, 2, dequantize_q5_1>; +template [[host_name("kernel_mul_mm_q8_0_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q8_0, 2, dequantize_q8_0>; +template [[host_name("kernel_mul_mm_q2_K_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q2_K, QK_NL, dequantize_q2_K>; +template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q3_K, QK_NL, dequantize_q3_K>; +template [[host_name("kernel_mul_mm_q4_K_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q4_K, QK_NL, dequantize_q4_K>; +template [[host_name("kernel_mul_mm_q5_K_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q5_K, QK_NL, dequantize_q5_K>; +template [[host_name("kernel_mul_mm_q6_K_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q6_K, QK_NL, dequantize_q6_K>; +template [[host_name("kernel_mul_mm_iq2_xxs_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq2_xxs, QK_NL, dequantize_iq2_xxs>; +template [[host_name("kernel_mul_mm_iq2_xs_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq2_xs, QK_NL, dequantize_iq2_xs>; +template [[host_name("kernel_mul_mm_iq3_xxs_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq3_xxs, QK_NL, dequantize_iq3_xxs>; +template [[host_name("kernel_mul_mm_iq3_s_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq3_s, QK_NL, dequantize_iq3_s>; +template [[host_name("kernel_mul_mm_iq2_s_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq2_s, QK_NL, dequantize_iq2_s>; +template [[host_name("kernel_mul_mm_iq1_s_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq1_s, QK_NL, dequantize_iq1_s>; +template [[host_name("kernel_mul_mm_iq1_m_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq1_m, QK_NL, dequantize_iq1_m>; +template [[host_name("kernel_mul_mm_iq4_nl_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq4_nl, 2, dequantize_iq4_nl>; +template [[host_name("kernel_mul_mm_iq4_xs_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq4_xs, QK_NL, dequantize_iq4_xs>; // // indirect matrix-matrix multiplication // -typedef decltype(kernel_mul_mm_id<float4x4, 1, dequantize_f32>) mat_mm_id_t; +typedef decltype(kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, float4x4, 1, dequantize_f32>) mul_mm_id; -template [[host_name("kernel_mul_mm_id_f32_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<float4x4, 1, dequantize_f32>; -template [[host_name("kernel_mul_mm_id_f16_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<half4x4, 1, dequantize_f16>; +template [[host_name("kernel_mul_mm_id_f32_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, float4x4, 1, dequantize_f32>; +template [[host_name("kernel_mul_mm_id_f16_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half4x4, 1, dequantize_f16>; #if defined(GGML_METAL_USE_BF16) -template [[host_name("kernel_mul_mm_id_bf16_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<bfloat4x4, 1, dequantize_bf16>; +template [[host_name("kernel_mul_mm_id_bf16_f16")]] kernel mul_mm_id kernel_mul_mm_id<bfloat, bfloat4x4, simdgroup_bfloat8x8, bfloat4x4, 1, dequantize_bf16>; #endif -template [[host_name("kernel_mul_mm_id_q4_0_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q4_0, 2, dequantize_q4_0>; -template [[host_name("kernel_mul_mm_id_q4_1_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q4_1, 2, dequantize_q4_1>; -template [[host_name("kernel_mul_mm_id_q5_0_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q5_0, 2, dequantize_q5_0>; -template [[host_name("kernel_mul_mm_id_q5_1_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q5_1, 2, dequantize_q5_1>; -template [[host_name("kernel_mul_mm_id_q8_0_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q8_0, 2, dequantize_q8_0>; -template [[host_name("kernel_mul_mm_id_q2_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q2_K, QK_NL, dequantize_q2_K>; -template [[host_name("kernel_mul_mm_id_q3_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q3_K, QK_NL, dequantize_q3_K>; -template [[host_name("kernel_mul_mm_id_q4_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q4_K, QK_NL, dequantize_q4_K>; -template [[host_name("kernel_mul_mm_id_q5_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q5_K, QK_NL, dequantize_q5_K>; -template [[host_name("kernel_mul_mm_id_q6_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q6_K, QK_NL, dequantize_q6_K>; -template [[host_name("kernel_mul_mm_id_iq2_xxs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq2_xxs, QK_NL, dequantize_iq2_xxs>; -template [[host_name("kernel_mul_mm_id_iq2_xs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq2_xs, QK_NL, dequantize_iq2_xs>; -template [[host_name("kernel_mul_mm_id_iq3_xxs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq3_xxs, QK_NL, dequantize_iq3_xxs>; -template [[host_name("kernel_mul_mm_id_iq3_s_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq3_s, QK_NL, dequantize_iq3_s>; -template [[host_name("kernel_mul_mm_id_iq2_s_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq2_s, QK_NL, dequantize_iq2_s>; -template [[host_name("kernel_mul_mm_id_iq1_s_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq1_s, QK_NL, dequantize_iq1_s>; -template [[host_name("kernel_mul_mm_id_iq1_m_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq1_m, QK_NL, dequantize_iq1_m>; -template [[host_name("kernel_mul_mm_id_iq4_nl_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq4_nl, 2, dequantize_iq4_nl>; -template [[host_name("kernel_mul_mm_id_iq4_xs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq4_xs, QK_NL, dequantize_iq4_xs>; +template [[host_name("kernel_mul_mm_id_q4_0_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_q4_0, 2, dequantize_q4_0>; +template [[host_name("kernel_mul_mm_id_q4_1_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_q4_1, 2, dequantize_q4_1>; +template [[host_name("kernel_mul_mm_id_q5_0_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_q5_0, 2, dequantize_q5_0>; +template [[host_name("kernel_mul_mm_id_q5_1_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_q5_1, 2, dequantize_q5_1>; +template [[host_name("kernel_mul_mm_id_q8_0_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_q8_0, 2, dequantize_q8_0>; +template [[host_name("kernel_mul_mm_id_q2_K_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_q2_K, QK_NL, dequantize_q2_K>; +template [[host_name("kernel_mul_mm_id_q3_K_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_q3_K, QK_NL, dequantize_q3_K>; +template [[host_name("kernel_mul_mm_id_q4_K_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_q4_K, QK_NL, dequantize_q4_K>; +template [[host_name("kernel_mul_mm_id_q5_K_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_q5_K, QK_NL, dequantize_q5_K>; +template [[host_name("kernel_mul_mm_id_q6_K_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_q6_K, QK_NL, dequantize_q6_K>; +template [[host_name("kernel_mul_mm_id_iq2_xxs_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_iq2_xxs, QK_NL, dequantize_iq2_xxs>; +template [[host_name("kernel_mul_mm_id_iq2_xs_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_iq2_xs, QK_NL, dequantize_iq2_xs>; +template [[host_name("kernel_mul_mm_id_iq3_xxs_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_iq3_xxs, QK_NL, dequantize_iq3_xxs>; +template [[host_name("kernel_mul_mm_id_iq3_s_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_iq3_s, QK_NL, dequantize_iq3_s>; +template [[host_name("kernel_mul_mm_id_iq2_s_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_iq2_s, QK_NL, dequantize_iq2_s>; +template [[host_name("kernel_mul_mm_id_iq1_s_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_iq1_s, QK_NL, dequantize_iq1_s>; +template [[host_name("kernel_mul_mm_id_iq1_m_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_iq1_m, QK_NL, dequantize_iq1_m>; +template [[host_name("kernel_mul_mm_id_iq4_nl_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_iq4_nl, 2, dequantize_iq4_nl>; +template [[host_name("kernel_mul_mm_id_iq4_xs_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_iq4_xs, QK_NL, dequantize_iq4_xs>; + // // matrix-vector multiplication diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index ee4fe9f723dc1..bc673292b37a3 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -2732,11 +2732,11 @@ void ggml_mul_mat_set_prec( c = ggml_mul_mat_id(ctx, as, b, ids); as -> [cols, rows, n_expert] - ids -> [n_experts_used, n_tokens] (i32) b -> [cols, n_expert_used, n_tokens] + ids -> [n_expert_used, n_tokens] (i32) c -> [rows, n_expert_used, n_tokens] - in b, n_experts_used can be broadcasted to match the n_expert_used of ids + in b, n_expert_used can be broadcasted to match the n_expert_used of ids c ~= as[:,:,i] @ b[:,i%r,t], i = ids[e,t] for all e,t in ids */