diff --git a/ggml/src/ggml-metal/ggml-metal-impl.h b/ggml/src/ggml-metal/ggml-metal-impl.h
index 8721b272de6a9..ea8cf45863af9 100644
--- a/ggml/src/ggml-metal/ggml-metal-impl.h
+++ b/ggml/src/ggml-metal/ggml-metal-impl.h
@@ -299,21 +299,42 @@ typedef struct {
 } ggml_metal_kargs_mul_mv_ext;
 
 typedef struct {
-    int32_t  nei0;
-    int32_t  nei1;
-    uint64_t nbi1;
+    int32_t  ne10;
+    int32_t  ne11;  // n_expert_used (bcast)
+    uint64_t nb11;
+    uint64_t nb12;
+    int32_t  neh11; // n_tokens
+    uint64_t nbh11;
+    int32_t  ne20;  // n_expert_used
+    uint64_t nb21;
+} ggml_metal_kargs_mul_mm_id_map0;
+
+typedef struct {
+    int32_t  ne20; // n_expert_used
+    int32_t  neh0;
+    int32_t  neh1;
+    uint64_t nbh1;
+    uint64_t nbh2;
+    int32_t  ne0;
+    uint64_t nb1;
+    uint64_t nb2;
+} ggml_metal_kargs_mul_mm_id_map1;
+
+typedef struct {
     int32_t  ne00;
     int32_t  ne02;
     uint64_t nb01;
     uint64_t nb02;
-    int32_t  ne11;
-    int32_t  ne12;
-    int32_t  ne13;
-    uint64_t nb10;
-    uint64_t nb11;
-    uint64_t nb12;
-    int32_t  ne0;
-    int32_t  ne1;
+    uint64_t nb03;
+    int32_t  neh12;
+    uint64_t nbh10;
+    uint64_t nbh11;
+    uint64_t nbh12;
+    uint64_t nbh13;
+    int32_t  neh0;
+    int32_t  neh1;
+    int16_t  r2;
+    int16_t  r3;
 } ggml_metal_kargs_mul_mm_id;
 
 typedef struct {
diff --git a/ggml/src/ggml-metal/ggml-metal.m b/ggml/src/ggml-metal/ggml-metal.m
index d92392edb7eb1..943f0fe722a2f 100644
--- a/ggml/src/ggml-metal/ggml-metal.m
+++ b/ggml/src/ggml-metal/ggml-metal.m
@@ -306,28 +306,30 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
     GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32,
-    GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32,
-    GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32,
-    GGML_METAL_KERNEL_TYPE_MUL_MM_ID_BF16_F32,
-    GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32,
-    GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F32,
-    GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F32,
-    GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F32,
-    GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F32,
-    GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F32,
-    GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F32,
-    GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F32,
-    GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_K_F32,
-    GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F32,
-    GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32,
-    GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32,
-    GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32,
-    GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_S_F32,
-    GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F32,
-    GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32,
-    GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F32,
-    GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32,
-    GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16,
+    GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP1_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F16,
+    GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F16,
+    GGML_METAL_KERNEL_TYPE_MUL_MM_ID_BF16_F16,
+    GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F16,
+    GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F16,
+    GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F16,
+    GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F16,
+    GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F16,
+    GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F16,
+    GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F16,
+    GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F16,
+    GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_K_F16,
+    GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F16,
+    GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F16,
+    GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F16,
+    GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F16,
+    GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_S_F16,
+    GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F16,
+    GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F16,
+    GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F16,
+    GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F16,
+    GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F16,
     GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32,
     GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16,
     GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32,
@@ -650,7 +652,8 @@ static void ggml_metal_mem_pool_reset(struct ggml_metal_mem_pool * mem_pool) {
     }
 
     if (mem_pool->heaps_to_remove.count > 0) {
-        for (NSUInteger i = 0; i < [mem_pool->heaps_to_remove count]; i++) {
+        // remove in reverse order
+        for (NSUInteger i = [mem_pool->heaps_to_remove count] - 1; ; --i) {
             NSUInteger index = [[mem_pool->heaps_to_remove objectAtIndex:i] intValue];
             ggml_metal_heap_ptr * ptr = [mem_pool->heaps objectAtIndex:index];
 
@@ -659,6 +662,10 @@ static void ggml_metal_mem_pool_reset(struct ggml_metal_mem_pool * mem_pool) {
 
             [mem_pool->heaps removeObjectAtIndex:index];
             [ptr release];
+
+            if (i == 0) {
+                break;
+            }
         }
 
         [mem_pool->heaps_to_remove removeAllObjects];
@@ -672,7 +679,7 @@ static void ggml_metal_mem_pool_clear(struct ggml_metal_mem_pool * mem_pool) {
 }
 
 static id<MTLBuffer> ggml_metal_mem_pool_alloc(struct ggml_metal_mem_pool * mem_pool, size_t size) {
-    const size_t alignment = 32;
+    const size_t alignment = 256;
 
     const size_t size_aligned = GGML_PAD(size, alignment);
 
@@ -1242,28 +1249,30 @@ @implementation GGMLMetalClass
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F32,                mul_mm_iq1_m_f32,                has_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32,               mul_mm_iq4_nl_f32,               has_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32,               mul_mm_iq4_xs_f32,               has_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32,               mul_mm_id_f32_f32,               has_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32,               mul_mm_id_f16_f32,               has_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_BF16_F32,              mul_mm_id_bf16_f32,              has_simdgroup_mm && use_bfloat);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32,              mul_mm_id_q4_0_f32,              has_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F32,              mul_mm_id_q4_1_f32,              has_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F32,              mul_mm_id_q5_0_f32,              has_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F32,              mul_mm_id_q5_1_f32,              has_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F32,              mul_mm_id_q8_0_f32,              has_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F32,              mul_mm_id_q2_K_f32,              has_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F32,              mul_mm_id_q3_K_f32,              has_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F32,              mul_mm_id_q4_K_f32,              has_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_K_F32,              mul_mm_id_q5_K_f32,              has_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F32,              mul_mm_id_q6_K_f32,              has_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32,           mul_mm_id_iq2_xxs_f32,           has_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32,            mul_mm_id_iq2_xs_f32,            has_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32,           mul_mm_id_iq3_xxs_f32,           has_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_S_F32,             mul_mm_id_iq3_s_f32,             has_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F32,             mul_mm_id_iq2_s_f32,             has_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32,             mul_mm_id_iq1_s_f32,             has_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F32,             mul_mm_id_iq1_m_f32,             has_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32,            mul_mm_id_iq4_nl_f32,            has_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32,            mul_mm_id_iq4_xs_f32,            has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16,              mul_mm_id_map0_f16,              has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP1_F32,              mul_mm_id_map1_f32,              has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F16,               mul_mm_id_f32_f16,               has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F16,               mul_mm_id_f16_f16,               has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_BF16_F16,              mul_mm_id_bf16_f16,              has_simdgroup_mm && use_bfloat);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F16,              mul_mm_id_q4_0_f16,              has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F16,              mul_mm_id_q4_1_f16,              has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F16,              mul_mm_id_q5_0_f16,              has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F16,              mul_mm_id_q5_1_f16,              has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F16,              mul_mm_id_q8_0_f16,              has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F16,              mul_mm_id_q2_K_f16,              has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F16,              mul_mm_id_q3_K_f16,              has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F16,              mul_mm_id_q4_K_f16,              has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_K_F16,              mul_mm_id_q5_K_f16,              has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F16,              mul_mm_id_q6_K_f16,              has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F16,           mul_mm_id_iq2_xxs_f16,           has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F16,            mul_mm_id_iq2_xs_f16,            has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F16,           mul_mm_id_iq3_xxs_f16,           has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_S_F16,             mul_mm_id_iq3_s_f16,             has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F16,             mul_mm_id_iq2_s_f16,             has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F16,             mul_mm_id_iq1_s_f16,             has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F16,             mul_mm_id_iq1_m_f16,             has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F16,            mul_mm_id_iq4_nl_f16,            has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F16,            mul_mm_id_iq4_xs_f16,            has_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32,                   rope_norm_f32,                   true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16,                   rope_norm_f16,                   true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32,                   rope_neox_f32,                   true);
@@ -2999,7 +3008,7 @@ static bool ggml_metal_encode_node(
                     [encoder setBuffer:id_dst  offset:offs_dst     atIndex:3];
 
                     [encoder setThreadgroupMemoryLength:8192 atIndex:0];
-                    [encoder dispatchThreadgroups:MTLSizeMake( (ne11 + 31)/32, (ne01 + 63)/64, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
+                    [encoder dispatchThreadgroups:MTLSizeMake((ne11 + 31)/32, (ne01 + 63)/64, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
                 } else {
                     id<MTLComputePipelineState> pipeline = nil;
 
@@ -3219,8 +3228,6 @@ static bool ggml_metal_encode_node(
             } break;
         case GGML_OP_MUL_MAT_ID:
             {
-                const int n_as = src0->ne[2];
-
                 // src2 = ids
                 const enum ggml_type src2t = src2->type; GGML_UNUSED(src2t);
 
@@ -3234,24 +3241,21 @@ static bool ggml_metal_encode_node(
                 GGML_ASSERT(ne03 == 1);
                 GGML_ASSERT(ne13 == 1);
 
+                const uint32_t r2 = 1;
+                const uint32_t r3 = 1;
+
                 // find the break-even point where the matrix-matrix kernel becomes more efficient compared
                 // to the matrix-vector kernel
                 // ne20 = n_used_experts
-                // ne21 = n_rows
-                const int dst_rows = ne20*ne21;
-                const int dst_rows_min = n_as;
-                const int dst_rows_max = (device.maxThreadgroupMemoryLength/2 - 8192)/4;
-
-                // max size of the rowids array in the kernel shared buffer
-                //GGML_ASSERT(dst_rows <= dst_rows_max);
+                // ne21 = n_rows (batch size)
+                const int ne21_mm_id_min = 32;
 
                 // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
                 // AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
                 if ([device supportsFamily:MTLGPUFamilyApple7] &&
                         ne00 % 32 == 0 && ne00 >= 64 &&
-                        //ne01 / ne02 >= 512 &&    // NOTE: this is based on Mixtral shapes, might need adjustments
-                        dst_rows >  dst_rows_min &&
-                        dst_rows <= dst_rows_max) {
+                        (ne21 >= ne21_mm_id_min)) {
+                    GGML_ASSERT(ne00 % 4 == 0);
 
                     // some Metal matrix data types require aligned pointers
                     // ref: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf (Table 2.5)
@@ -3262,62 +3266,169 @@ static bool ggml_metal_encode_node(
                         default: break;
                     }
 
-                    id<MTLComputePipelineState> pipeline = nil;
+                    const int64_t neh10 = ne10; // n_embd
+                    const int64_t neh11 = ne21; // n_tokens
+                    const int64_t neh12 = ne02; // n_expert
 
-                    switch (src0->type) {
-                        case GGML_TYPE_F32:     pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32    ].pipeline; break;
-                        case GGML_TYPE_F16:     pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32    ].pipeline; break;
-                        case GGML_TYPE_BF16:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_BF16_F32   ].pipeline; break;
-                        case GGML_TYPE_Q4_0:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32   ].pipeline; break;
-                        case GGML_TYPE_Q4_1:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F32   ].pipeline; break;
-                        case GGML_TYPE_Q5_0:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F32   ].pipeline; break;
-                        case GGML_TYPE_Q5_1:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F32   ].pipeline; break;
-                        case GGML_TYPE_Q8_0:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F32   ].pipeline; break;
-                        case GGML_TYPE_Q2_K:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F32   ].pipeline; break;
-                        case GGML_TYPE_Q3_K:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F32   ].pipeline; break;
-                        case GGML_TYPE_Q4_K:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F32   ].pipeline; break;
-                        case GGML_TYPE_Q5_K:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_K_F32   ].pipeline; break;
-                        case GGML_TYPE_Q6_K:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F32   ].pipeline; break;
-                        case GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32].pipeline; break;
-                        case GGML_TYPE_IQ2_XS:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32 ].pipeline; break;
-                        case GGML_TYPE_IQ3_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32].pipeline; break;
-                        case GGML_TYPE_IQ3_S:   pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_S_F32  ].pipeline; break;
-                        case GGML_TYPE_IQ2_S:   pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F32  ].pipeline; break;
-                        case GGML_TYPE_IQ1_S:   pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32  ].pipeline; break;
-                        case GGML_TYPE_IQ1_M:   pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F32  ].pipeline; break;
-                        case GGML_TYPE_IQ4_NL:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32 ].pipeline; break;
-                        case GGML_TYPE_IQ4_XS:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32 ].pipeline; break;
-                        default: GGML_ABORT("MUL_MAT_ID not implemented");
+                    const uint64_t nbh10 = ggml_type_size(GGML_TYPE_F16);
+                    const uint64_t nbh11 = nbh10*neh10;
+                    const uint64_t nbh12 = nbh11*neh11;
+                    const uint64_t nbh13 = nbh12*neh12;
+
+                    const size_t s_src1 = ggml_type_size(GGML_TYPE_F16)*neh10*neh11*neh12;
+                    id<MTLBuffer> h_src1 = ggml_metal_mem_pool_alloc(mem_pool, s_src1);
+                    if (!h_src1) {
+                        GGML_LOG_ERROR("%s: failed to allocate buffer from memory pool, size = %zu\n", __func__, s_src1);
+                        return false;
                     }
 
-                    ggml_metal_kargs_mul_mm_id args = {
-                        /*.nei0 =*/ ne20,
-                        /*.nei1 =*/ ne21,
-                        /*.nbi1 =*/ nb21,
-                        /*.ne00 =*/ ne00,
-                        /*.ne02 =*/ ne02,
-                        /*.nb01 =*/ nb01,
-                        /*.nb02 =*/ nb02,
-                        /*.ne11 =*/ ne11,
-                        /*.ne12 =*/ ne12,
-                        /*.ne13 =*/ ne13,
-                        /*.nb10 =*/ nb10,
-                        /*.nb11 =*/ nb11,
-                        /*.nb12 =*/ nb12,
-                        /*.ne0  =*/ ne0,
-                        /*.ne1  =*/ ne1,
-                    };
+                    const int64_t neh0 = ne0;
+                    const int64_t neh1 = ne21;
+                    const int64_t neh2 = ne02;
 
-                    [encoder setComputePipelineState:pipeline];
-                    [encoder setBytes:&args    length:sizeof(args) atIndex:0];
-                    [encoder setBuffer:id_src0 offset:offs_src0    atIndex:1];
-                    [encoder setBuffer:id_src1 offset:offs_src1    atIndex:2];
-                    [encoder setBuffer:id_dst  offset:offs_dst     atIndex:3];
-                    [encoder setBuffer:id_src2 offset:offs_src2    atIndex:4];
+                    const uint64_t nbh0 = ggml_type_size(GGML_TYPE_F32);
+                    const uint64_t nbh1 = nbh0*neh0;
+                    const uint64_t nbh2 = nbh1*neh1;
+                  //const uint64_t nbh3 = nbh2*neh2;
+
+                    const size_t s_dst = ggml_type_size(GGML_TYPE_F32)*neh0*neh1*neh2;
+                    id<MTLBuffer> h_dst = ggml_metal_mem_pool_alloc(mem_pool, s_dst);
+                    if (!h_dst) {
+                        GGML_LOG_ERROR("%s: failed to allocate buffer from memory pool, size = %zu\n", __func__, s_dst);
+                        return false;
+                    }
+
+                    // tokens per expert
+                    const size_t s_tpe = ggml_type_size(GGML_TYPE_I32)*ne02;
+                    id<MTLBuffer> h_tpe = ggml_metal_mem_pool_alloc(mem_pool, s_tpe);
+                    if (!h_tpe) {
+                        GGML_LOG_ERROR("%s: failed to allocate buffer from memory pool, size = %zu\n", __func__, s_tpe);
+                        return false;
+                    }
+
+                    // id map
+                    // [n_expert_used, n_tokens]
+                    const size_t s_ids = ggml_type_size(GGML_TYPE_I32)*ne20*ne21;
+                    id<MTLBuffer> h_ids = ggml_metal_mem_pool_alloc(mem_pool, s_ids);
+                    if (!h_ids) {
+                        GGML_LOG_ERROR("%s: failed to allocate buffer from memory pool, size = %zu\n", __func__, s_ids);
+                        return false;
+                    }
+
+                    {
+                        const int nth = MIN(1024, ne10/4);
+
+                        ggml_metal_kargs_mul_mm_id_map0 args = {
+                            ne10,
+                            ne11,  // n_expert_used (bcast)
+                            nb11,
+                            nb12,
+                            neh11, // n_tokens
+                            nbh11,
+                            ne20,  // n_expert_used
+                            nb21,
+                        };
+
+                        id<MTLComputePipelineState> pipeline = nil;
+
+                        pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16].pipeline;
+
+                        [encoder setComputePipelineState:pipeline];
+                        [encoder setBytes:&args    length:sizeof(args) atIndex:0];
+                        [encoder setBuffer:id_src1 offset:offs_src1    atIndex:1];
+                        [encoder setBuffer:id_src2 offset:offs_src2    atIndex:2];
+                        [encoder setBuffer: h_src1 offset:0            atIndex:3];
+                        [encoder setBuffer: h_tpe  offset:0            atIndex:4];
+                        [encoder setBuffer: h_ids  offset:0            atIndex:5];
+
+                        [encoder dispatchThreadgroups:MTLSizeMake(ne02, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+                    }
+
+                    {
+                        id<MTLComputePipelineState> pipeline = nil;
+
+                        switch (src0->type) {
+                            case GGML_TYPE_F32:     pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F16    ].pipeline; break;
+                            case GGML_TYPE_F16:     pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F16    ].pipeline; break;
+                            case GGML_TYPE_BF16:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_BF16_F16   ].pipeline; break;
+                            case GGML_TYPE_Q4_0:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F16   ].pipeline; break;
+                            case GGML_TYPE_Q4_1:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F16   ].pipeline; break;
+                            case GGML_TYPE_Q5_0:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F16   ].pipeline; break;
+                            case GGML_TYPE_Q5_1:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F16   ].pipeline; break;
+                            case GGML_TYPE_Q8_0:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F16   ].pipeline; break;
+                            case GGML_TYPE_Q2_K:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F16   ].pipeline; break;
+                            case GGML_TYPE_Q3_K:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F16   ].pipeline; break;
+                            case GGML_TYPE_Q4_K:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F16   ].pipeline; break;
+                            case GGML_TYPE_Q5_K:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_K_F16   ].pipeline; break;
+                            case GGML_TYPE_Q6_K:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F16   ].pipeline; break;
+                            case GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F16].pipeline; break;
+                            case GGML_TYPE_IQ2_XS:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F16 ].pipeline; break;
+                            case GGML_TYPE_IQ3_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F16].pipeline; break;
+                            case GGML_TYPE_IQ3_S:   pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_S_F16  ].pipeline; break;
+                            case GGML_TYPE_IQ2_S:   pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F16  ].pipeline; break;
+                            case GGML_TYPE_IQ1_S:   pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F16  ].pipeline; break;
+                            case GGML_TYPE_IQ1_M:   pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F16  ].pipeline; break;
+                            case GGML_TYPE_IQ4_NL:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F16 ].pipeline; break;
+                            case GGML_TYPE_IQ4_XS:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F16 ].pipeline; break;
+                            default: GGML_ABORT("MUL_MAT_ID not implemented");
+                        }
+
+                        ggml_metal_kargs_mul_mm_id args = {
+                            /*.ne00  =*/ ne00,
+                            /*.ne02  =*/ ne02,
+                            /*.nb01  =*/ nb01,
+                            /*.nb02  =*/ nb02,
+                            /*.nb03  =*/ nb03,
+                            /*.neh12 =*/ neh12,
+                            /*.nbh10 =*/ nbh10,
+                            /*.nbh11 =*/ nbh11,
+                            /*.nbh12 =*/ nbh12,
+                            /*.nbh13 =*/ nbh13,
+                            /*.neh0  =*/ neh0,
+                            /*.neh1  =*/ neh1,
+                            /*.r2    =*/ r2,
+                            /*.r3    =*/ r3,
+                        };
+
+                        [encoder setComputePipelineState:pipeline];
+                        [encoder setBytes:&args    length:sizeof(args) atIndex:0];
+                        [encoder setBuffer:id_src0 offset:offs_src0    atIndex:1];
+                        [encoder setBuffer: h_src1 offset:0            atIndex:2];
+                        [encoder setBuffer: h_tpe  offset:0            atIndex:3];
+                        [encoder setBuffer: h_dst  offset:0            atIndex:4];
+
+                        [encoder setThreadgroupMemoryLength:8192 atIndex:0];
+                        [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 31)/32, (ne01 + 63)/64, ne02) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
+                    }
 
-                    [encoder setThreadgroupMemoryLength:GGML_PAD(8192 + dst_rows*4/*sizeof(ushort2)*/, 16) atIndex:0];
+                    {
+                        GGML_ASSERT(ne0 % 4 == 0);
 
-                    [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 31)/32, (ne01 + 63)/64, n_as) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
+                        const int nth = MIN(1024, ne0/4);
+
+                        ggml_metal_kargs_mul_mm_id_map1 args = {
+                            ne20, // n_expert_used
+                            neh0,
+                            neh1,
+                            nbh1,
+                            nbh2,
+                            ne0,
+                            nb1,
+                            nb2,
+                        };
+
+                        id<MTLComputePipelineState> pipeline = nil;
+
+                        pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP1_F32].pipeline;
+
+                        [encoder setComputePipelineState:pipeline];
+                        [encoder setBytes:&args   length:sizeof(args) atIndex:0];
+                        [encoder setBuffer: h_dst offset:0            atIndex:1];
+                        [encoder setBuffer: h_ids offset:0            atIndex:2];
+                        [encoder setBuffer:id_dst offset:offs_dst     atIndex:3];
+
+                        [encoder dispatchThreadgroups:MTLSizeMake(ne20, ne21, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+                    }
                 } else {
                     id<MTLComputePipelineState> pipeline = nil;
 
@@ -3511,7 +3622,7 @@ static bool ggml_metal_encode_node(
                     [encoder setBuffer:id_src2 offset:offs_src2 atIndex:4];
 
                     const int64_t _ne1 = 1;
-                    const int64_t ne123 = dst_rows;
+                    const int64_t ne123 = ne20*ne21;
 
                     if (smem > 0) {
                         [encoder setThreadgroupMemoryLength:smem atIndex:0];
diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal
index 9f4147e93974d..a808e79c22752 100644
--- a/ggml/src/ggml-metal/ggml-metal.metal
+++ b/ggml/src/ggml-metal/ggml-metal.metal
@@ -6336,127 +6336,219 @@ kernel void kernel_mul_mm(
     }
 }
 
-// same as kernel_mul_mm_impl, but src1 and dst are accessed via indices stored in rowids
-// TODO: this kernel needs to be reimplemented from scratch for better performance
-template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
-void kernel_mul_mm_id_impl(
-        int32_t  ne00,
-        int32_t  ne02,
-        uint64_t nb01,
-        uint64_t nb02,
-        int32_t  ne11,
-        int32_t  ne12,
-        uint64_t nb10,
-        uint64_t nb11,
-        uint64_t nb12,
-        int32_t  ne0,
-        int32_t  ne1,
-        int64_t  ne0ne1,
-        device   const char * src0,
-        device   const char * src1,
-        threadgroup ushort2 * rowids,
-        device         char * dst,
-        threadgroup    char * shmem,
+template<typename T4>
+kernel void kernel_mul_mm_id_map0(
+        constant ggml_metal_kargs_mul_mm_id_map0 & args,
+        device  const char * src1,
+        device  const char * src2,
+        device        char * hsrc1,
+        device        char * htpe,
+        device        char * hids,
+        uint3   tgpig[[threadgroup_position_in_grid]],
+        ushort3 tpitg[[thread_position_in_threadgroup]],
+        ushort3   ntg[[threads_per_threadgroup]]) {
+    const int ide = tgpig[0]; // expert id
+
+    int n_all = 0;
+
+    device int32_t * ids_i32 = (device int32_t *) (hids);
+
+    for (int i21 = 0; i21 < args.neh11; i21++) { // n_tokens
+        device const int32_t * src2_i32 = (device const int32_t *) (src2 + i21*args.nb21);
+
+        for (int i20 = 0; i20 < args.ne20; i20++) { // n_expert_used
+            if (src2_i32[i20] != ide) {
+                continue;
+            }
+
+            device const float4 *  src1_f32x4 = (device const float4 *) ( src1 + i21*args.nb12 + (i20%args.ne11)*args.nb11);
+            device       T4     * hsrc1_f32x4 = (device       T4     *) (hsrc1 + (ide*args.neh11 + n_all)*args.nbh11);
+
+            for (int64_t i00 = tpitg.x; i00 < args.ne10/4; i00 += ntg.x) {
+                hsrc1_f32x4[i00] = (T4) (src1_f32x4[i00]);
+            }
+
+            if (tpitg.x == 0) {
+                ids_i32[i21*args.ne20 + i20] = ide*args.neh11 + n_all;
+            }
+
+            ++n_all;
+        }
+    }
+
+    if (tpitg.x == 0) {
+        device int32_t * tpe_i32 = (device int32_t *) (htpe);
+        tpe_i32[ide] = n_all;
+    }
+}
+
+typedef decltype(kernel_mul_mm_id_map0<half4>) kernel_mul_mm_id_map0_t;
+
+template [[host_name("kernel_mul_mm_id_map0_f16")]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<half4>;
+
+template<typename T>
+kernel void kernel_mul_mm_id_map1(
+        constant ggml_metal_kargs_mul_mm_id_map1 & args,
+        device  const char * hdst,
+        device  const char * hids,
+        device        char * dst,
+        uint3   tgpig[[threadgroup_position_in_grid]],
+        ushort3 tpitg[[thread_position_in_threadgroup]],
+        ushort3   ntg[[threads_per_threadgroup]]) {
+    const int i20 = tgpig[0]; // used expert
+    const int i21 = tgpig[1]; // token
+
+    device const int32_t * ids_i32    = (device const int32_t *) (hids);
+    device       float4  * dst_f32x4  = (device       float4  *) (dst + i20*args.nb1 + i21*args.nb2);
+
+    const int id = ids_i32[i21*args.ne20 + i20];
+
+    const int ide = id / args.neh1;
+    const int idt = id % args.neh1;
+
+    device const float4 * hdst_f32x4 = (device const float4 *) (hdst + idt*args.nbh1 + ide*args.nbh2);
+
+    for (int64_t i0 = tpitg.x; i0 < args.neh0/4; i0 += ntg.x) {
+        dst_f32x4[i0] = hdst_f32x4[i0];
+    }
+}
+
+typedef decltype(kernel_mul_mm_id_map1<float>) kernel_mul_mm_id_map1_t;
+
+template [[host_name("kernel_mul_mm_id_map1_f32")]] kernel kernel_mul_mm_id_map1_t kernel_mul_mm_id_map1<float>;
+
+template<typename T, typename T4x4, typename simdgroup_T8x8, typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread T4x4 &)>
+kernel void kernel_mul_mm_id(
+        constant ggml_metal_kargs_mul_mm_id & args,
+        device const char * src0,
+        device const char * src1,
+        device const char * tpe,
+        device       char * dst,
+        threadgroup  char * shmem [[threadgroup(0)]],
         uint3  tgpig[[threadgroup_position_in_grid]],
         ushort tiitg[[thread_index_in_threadgroup]],
         ushort sgitg[[simdgroup_index_in_threadgroup]]) {
 
-    threadgroup half  * sa = (threadgroup half  *)(shmem);
-    threadgroup float * sb = (threadgroup float *)(shmem + 4096);
+    threadgroup T    * sa = (threadgroup T    *)(shmem);
+    threadgroup half * sb = (threadgroup half *)(shmem + 4096);
 
     const int r0 = tgpig.y;
     const int r1 = tgpig.x;
+    const int im = tgpig.z;
+
+    device const int32_t * tpe_i32 = (device const int32_t *) (tpe);
+
+    const int neh1 = tpe_i32[im];
 
-    if (r1*BLOCK_SIZE_N >= ne1) return;
+    if (r1*BLOCK_SIZE_N >= neh1) {
+        return;
+    }
 
     // if this block is of 64x32 shape or smaller
-    short n_rows = (ne0 - r0 * BLOCK_SIZE_M < BLOCK_SIZE_M) ? (ne0 - r0 * BLOCK_SIZE_M) : BLOCK_SIZE_M;
-    short n_cols = (ne1 - r1 * BLOCK_SIZE_N < BLOCK_SIZE_N) ? (ne1 - r1 * BLOCK_SIZE_N) : BLOCK_SIZE_N;
+    const short n_rows = (args.neh0 - r0*BLOCK_SIZE_M < BLOCK_SIZE_M) ? (args.neh0 - r0*BLOCK_SIZE_M) : BLOCK_SIZE_M;
+    const short n_cols = (     neh1 - r1*BLOCK_SIZE_N < BLOCK_SIZE_N) ? (     neh1 - r1*BLOCK_SIZE_N) : BLOCK_SIZE_N;
 
     // a thread shouldn't load data outside of the matrix
-    short thread_row = ((short)tiitg/THREAD_PER_ROW) < n_rows ? ((short)tiitg/THREAD_PER_ROW) : n_rows - 1;
-    short thread_col = ((short)tiitg/THREAD_PER_COL) < n_cols ? ((short)tiitg/THREAD_PER_COL) : n_cols - 1;
+    const short thread_row = ((short)tiitg/THREAD_PER_ROW) < n_rows ? ((short)tiitg/THREAD_PER_ROW) : n_rows - 1;
+    const short thread_col = ((short)tiitg/THREAD_PER_COL) < n_cols ? ((short)tiitg/THREAD_PER_COL) : n_cols - 1;
 
-    simdgroup_half8x8  ma[4];
-    simdgroup_float8x8 mb[2];
+    simdgroup_T8x8     ma[4];
+    simdgroup_half8x8  mb[2];
     simdgroup_float8x8 mc[8];
-    for (int i = 0; i < 8; i++){
+
+    for (short i = 0; i < 8; i++){
         mc[i] = make_filled_simdgroup_matrix<float, 8>(0.f);
     }
+
     short il = (tiitg % THREAD_PER_ROW);
 
-    ushort offset1 = il/nl;
+    const int i12 = im%args.neh12;
+    const int i13 = im/args.neh12;
 
-    threadgroup const auto & id = rowids[r1 * BLOCK_SIZE_N + thread_col];
+    const uint64_t offset0 = (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
+    const short    offset1 = il/nl;
 
-    device const block_q * x = (device const block_q *)(src0 + (r0 * BLOCK_SIZE_M + thread_row) * nb01) + offset1;
-    device const float   * y = (device const float   *)(src1
-        + nb12 * id[1]
-        + nb11 * (id[0] % ne11)
-        + nb10 * (BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL)));
+    device const block_q * x = (device const block_q *)(src0
+        + args.nb01*(r0*BLOCK_SIZE_M + thread_row) + offset0) + offset1;
 
-    for (int loop_k = 0; loop_k < ne00; loop_k += BLOCK_SIZE_K) {
+    device const half   * y = (device const half   *)(src1
+        + args.nbh13*i13
+        + args.nbh12*i12
+        + args.nbh11*(r1*BLOCK_SIZE_N + thread_col)
+        + args.nbh10*(BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL)));
+
+    for (int loop_k = 0; loop_k < args.ne00; loop_k += BLOCK_SIZE_K) {
         // load data and store to threadgroup memory
-        half4x4 temp_a;
+        T4x4 temp_a;
         dequantize_func(x, il, temp_a);
+
         threadgroup_barrier(mem_flags::mem_threadgroup);
 
-        for (int i = 0; i < 16; i++) {
-            *(sa + SG_MAT_SIZE * ((tiitg / THREAD_PER_ROW / 8) \
-            +                     (tiitg % THREAD_PER_ROW) * 16 + (i / 8) * 8) \
-            +                     (tiitg / THREAD_PER_ROW) % 8  + (i & 7) * 8) = temp_a[i/4][i%4];
+        #pragma unroll(16)
+        for (short i = 0; i < 16; i++) {
+            *(sa + SG_MAT_SIZE * ((tiitg/THREAD_PER_ROW/8) \
+            +                     (tiitg%THREAD_PER_ROW)*16 + (i/8)*8) \
+            +                     (tiitg/THREAD_PER_ROW)%8  + (i&7)*8) = temp_a[i/4][i%4];
         }
 
-        *(threadgroup float2x4 *)(sb + (tiitg % THREAD_PER_COL) * 8 * 32 + 8 * (tiitg / THREAD_PER_COL)) = *((device float2x4 *)y);
+        *(threadgroup half2x4 *)(sb + 32*8*(tiitg%THREAD_PER_COL) + 8*(tiitg/THREAD_PER_COL)) = *((device half2x4 *) y);
 
         il = (il + 2 < nl) ? il + 2 : il % 2;
-        x  = (il < 2) ? x + (2+nl-1)/nl : x;
+        x  = (il < 2) ? x + (2 + nl - 1)/nl : x;
         y += BLOCK_SIZE_K;
 
         threadgroup_barrier(mem_flags::mem_threadgroup);
 
         // load matrices from threadgroup memory and conduct outer products
-        threadgroup half  * lsma = (sa + THREAD_MAT_M * SG_MAT_SIZE * (sgitg % 2));
-        threadgroup float * lsmb = (sb + THREAD_MAT_N * SG_MAT_SIZE * (sgitg / 2));
+        threadgroup const T    * lsma = (sa + THREAD_MAT_M*SG_MAT_SIZE*(sgitg%2));
+        threadgroup const half * lsmb = (sb + THREAD_MAT_N*SG_MAT_SIZE*(sgitg/2));
 
-        #pragma unroll(BLOCK_SIZE_K/8)
-        for (int ik = 0; ik < BLOCK_SIZE_K / 8; ik++) {
+        #pragma unroll(4)
+        for (short ik = 0; ik < BLOCK_SIZE_K/8; ik++) {
             #pragma unroll(4)
-            for (int i = 0; i < 4; i++) {
+            for (short i = 0; i < 4; i++) {
                 simdgroup_load(ma[i], lsma + SG_MAT_SIZE * i);
             }
+
             simdgroup_barrier(mem_flags::mem_none);
+
             #pragma unroll(2)
-            for (int i = 0; i < 2; i++) {
+            for (short i = 0; i < 2; i++) {
                 simdgroup_load(mb[i], lsmb + SG_MAT_SIZE * i);
             }
 
-            lsma += BLOCK_SIZE_M / SG_MAT_ROW * SG_MAT_SIZE;
-            lsmb += BLOCK_SIZE_N / SG_MAT_ROW * SG_MAT_SIZE;
-
             #pragma unroll(8)
-            for (int i = 0; i < 8; i++){
+            for (short i = 0; i < 8; i++){
                 simdgroup_multiply_accumulate(mc[i], mb[i/4], ma[i%4], mc[i]);
             }
+
+            lsma += (BLOCK_SIZE_M/SG_MAT_ROW)*SG_MAT_SIZE;
+            lsmb += (BLOCK_SIZE_N/SG_MAT_ROW)*SG_MAT_SIZE;
         }
     }
 
-    {
+    if ((r0 + 1) * BLOCK_SIZE_M <= args.neh0 && (r1 + 1) * BLOCK_SIZE_N <= neh1) {
+        device float * C = (device float *) dst +
+            (BLOCK_SIZE_M * r0 + 32*(sgitg &  1)) + \
+            (BLOCK_SIZE_N * r1 + 16*(sgitg >> 1)) * args.neh0 + im*args.neh1*args.neh0;
+
+        for (short i = 0; i < 8; i++) {
+            simdgroup_store(mc[i], C + 8 * (i%4) + 8 * args.neh0 * (i/4), args.neh0);
+        }
+    } else {
+        // block is smaller than 64x32, we should avoid writing data outside of the matrix
         threadgroup_barrier(mem_flags::mem_threadgroup);
         threadgroup float * temp_str = ((threadgroup float *) shmem) \
-                                      + 32 * (sgitg&1) + (16 * (sgitg>>1)) * BLOCK_SIZE_M;
-        for (int i = 0; i < 8; i++) {
-            simdgroup_store(mc[i], temp_str + 8 * (i%4) + 8 * BLOCK_SIZE_M * (i/4), BLOCK_SIZE_M);
+                                     + 32*(sgitg&1) + (16*(sgitg >> 1))*BLOCK_SIZE_M;
+        for (short i = 0; i < 8; i++) {
+            simdgroup_store(mc[i], temp_str + 8*(i%4) + 8*BLOCK_SIZE_M*(i/4), BLOCK_SIZE_M);
         }
 
         threadgroup_barrier(mem_flags::mem_threadgroup);
 
         if (sgitg == 0) {
             for (int j = tiitg; j < n_cols; j += BLOCK_SIZE_N) {
-                threadgroup const auto & jid = rowids[r1 * BLOCK_SIZE_N + j];
-                int64_t joff = jid[0]*ne0 + jid[1]*ne0ne1;
-
-                device float  * D  = (device float  *) dst + (r0*BLOCK_SIZE_M) + joff;
+                device float  * D  = (device float  *) dst + (r0*BLOCK_SIZE_M) + (r1*BLOCK_SIZE_N + j)*args.neh0 + im*args.neh1*args.neh0;
                 device float4 * D4 = (device float4 *) D;
 
                 threadgroup float  * C  = temp_str + (j*BLOCK_SIZE_M);
@@ -6476,66 +6568,6 @@ void kernel_mul_mm_id_impl(
     }
 }
 
-template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
-kernel void kernel_mul_mm_id(
-        constant ggml_metal_kargs_mul_mm_id & args,
-        device const char * src0s,
-        device const char * src1,
-        device       char * dst,
-        device const char * ids,
-        threadgroup  char * shmem [[threadgroup(0)]],
-        uint3  tgpig[[threadgroup_position_in_grid]],
-        ushort tiitg[[thread_index_in_threadgroup]],
-        ushort sgitg[[simdgroup_index_in_threadgroup]]) {
-
-    const int32_t i02 = tgpig.z;
-
-    tgpig.z = 0;
-
-    device const char * src0 = src0s + i02*args.nb02;
-
-    // row indices
-    threadgroup ushort2 * rowids = (threadgroup ushort2 *)(shmem + 8192);
-
-    // TODO: parallelize this loop
-    int32_t _ne1 = 0;
-    for (ushort ii1 = 0; ii1 < args.nei1; ii1++) {
-        for (ushort ii0 = 0; ii0 < args.nei0; ii0++) {
-            int32_t id = ((device int32_t *) (ids + ii1*args.nbi1))[ii0];
-            if (id == i02) {
-                if (tiitg == 0) {
-                    rowids[_ne1] = ushort2(ii0, ii1);
-                }
-                _ne1++;
-            }
-        }
-    }
-
-    threadgroup_barrier(mem_flags::mem_threadgroup);
-
-    kernel_mul_mm_id_impl<block_q, nl, dequantize_func>(
-        args.ne00,
-        args.ne02,
-        args.nb01,
-        args.nb02,
-        args.ne11,
-        args.ne12,
-        args.nb10,
-        args.nb11,
-        args.nb12,
-        args.ne0,
-        _ne1,
-        (int64_t)args.ne0*args.ne1,
-        src0,
-        src1,
-        rowids,
-        dst,
-        shmem,
-        tgpig,
-        tiitg,
-        sgitg);
-}
-
 #define QK_NL 16
 
 //
@@ -6576,63 +6608,64 @@ template [[host_name("kernel_get_rows_iq4_xs")]]  kernel get_rows_q_t kernel_get
 // matrix-matrix multiplication
 //
 
-typedef decltype(kernel_mul_mm<half, half4x4, simdgroup_half8x8, float4x4, 1, dequantize_f32>) mat_mm_t;
+typedef decltype(kernel_mul_mm<half, half4x4, simdgroup_half8x8, float4x4, 1, dequantize_f32>) mul_mm_t;
 
-template [[host_name("kernel_mul_mm_f32_f32")]]     kernel mat_mm_t kernel_mul_mm<half,   half4x4,   simdgroup_half8x8,   float4x4,      1,     dequantize_f32>;
-template [[host_name("kernel_mul_mm_f16_f32")]]     kernel mat_mm_t kernel_mul_mm<half,   half4x4,   simdgroup_half8x8,   half4x4,       1,     dequantize_f16>;
+template [[host_name("kernel_mul_mm_f32_f32")]]     kernel mul_mm_t kernel_mul_mm<half,   half4x4,   simdgroup_half8x8,   float4x4,      1,     dequantize_f32>;
+template [[host_name("kernel_mul_mm_f16_f32")]]     kernel mul_mm_t kernel_mul_mm<half,   half4x4,   simdgroup_half8x8,   half4x4,       1,     dequantize_f16>;
 #if defined(GGML_METAL_USE_BF16)
-template [[host_name("kernel_mul_mm_bf16_f32")]]    kernel mat_mm_t kernel_mul_mm<bfloat, bfloat4x4, simdgroup_bfloat8x8, bfloat4x4,     1,     dequantize_bf16>;
+template [[host_name("kernel_mul_mm_bf16_f32")]]    kernel mul_mm_t kernel_mul_mm<bfloat, bfloat4x4, simdgroup_bfloat8x8, bfloat4x4,     1,     dequantize_bf16>;
 #endif
-template [[host_name("kernel_mul_mm_q4_0_f32")]]    kernel mat_mm_t kernel_mul_mm<half,   half4x4,   simdgroup_half8x8,   block_q4_0,    2,     dequantize_q4_0>;
-template [[host_name("kernel_mul_mm_q4_1_f32")]]    kernel mat_mm_t kernel_mul_mm<half,   half4x4,   simdgroup_half8x8,   block_q4_1,    2,     dequantize_q4_1>;
-template [[host_name("kernel_mul_mm_q5_0_f32")]]    kernel mat_mm_t kernel_mul_mm<half,   half4x4,   simdgroup_half8x8,   block_q5_0,    2,     dequantize_q5_0>;
-template [[host_name("kernel_mul_mm_q5_1_f32")]]    kernel mat_mm_t kernel_mul_mm<half,   half4x4,   simdgroup_half8x8,   block_q5_1,    2,     dequantize_q5_1>;
-template [[host_name("kernel_mul_mm_q8_0_f32")]]    kernel mat_mm_t kernel_mul_mm<half,   half4x4,   simdgroup_half8x8,   block_q8_0,    2,     dequantize_q8_0>;
-template [[host_name("kernel_mul_mm_q2_K_f32")]]    kernel mat_mm_t kernel_mul_mm<half,   half4x4,   simdgroup_half8x8,   block_q2_K,    QK_NL, dequantize_q2_K>;
-template [[host_name("kernel_mul_mm_q3_K_f32")]]    kernel mat_mm_t kernel_mul_mm<half,   half4x4,   simdgroup_half8x8,   block_q3_K,    QK_NL, dequantize_q3_K>;
-template [[host_name("kernel_mul_mm_q4_K_f32")]]    kernel mat_mm_t kernel_mul_mm<half,   half4x4,   simdgroup_half8x8,   block_q4_K,    QK_NL, dequantize_q4_K>;
-template [[host_name("kernel_mul_mm_q5_K_f32")]]    kernel mat_mm_t kernel_mul_mm<half,   half4x4,   simdgroup_half8x8,   block_q5_K,    QK_NL, dequantize_q5_K>;
-template [[host_name("kernel_mul_mm_q6_K_f32")]]    kernel mat_mm_t kernel_mul_mm<half,   half4x4,   simdgroup_half8x8,   block_q6_K,    QK_NL, dequantize_q6_K>;
-template [[host_name("kernel_mul_mm_iq2_xxs_f32")]] kernel mat_mm_t kernel_mul_mm<half,   half4x4,   simdgroup_half8x8,   block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
-template [[host_name("kernel_mul_mm_iq2_xs_f32")]]  kernel mat_mm_t kernel_mul_mm<half,   half4x4,   simdgroup_half8x8,   block_iq2_xs,  QK_NL, dequantize_iq2_xs>;
-template [[host_name("kernel_mul_mm_iq3_xxs_f32")]] kernel mat_mm_t kernel_mul_mm<half,   half4x4,   simdgroup_half8x8,   block_iq3_xxs, QK_NL, dequantize_iq3_xxs>;
-template [[host_name("kernel_mul_mm_iq3_s_f32")]]   kernel mat_mm_t kernel_mul_mm<half,   half4x4,   simdgroup_half8x8,   block_iq3_s,   QK_NL, dequantize_iq3_s>;
-template [[host_name("kernel_mul_mm_iq2_s_f32")]]   kernel mat_mm_t kernel_mul_mm<half,   half4x4,   simdgroup_half8x8,   block_iq2_s,   QK_NL, dequantize_iq2_s>;
-template [[host_name("kernel_mul_mm_iq1_s_f32")]]   kernel mat_mm_t kernel_mul_mm<half,   half4x4,   simdgroup_half8x8,   block_iq1_s,   QK_NL, dequantize_iq1_s>;
-template [[host_name("kernel_mul_mm_iq1_m_f32")]]   kernel mat_mm_t kernel_mul_mm<half,   half4x4,   simdgroup_half8x8,   block_iq1_m,   QK_NL, dequantize_iq1_m>;
-template [[host_name("kernel_mul_mm_iq4_nl_f32")]]  kernel mat_mm_t kernel_mul_mm<half,   half4x4,   simdgroup_half8x8,   block_iq4_nl,  2,     dequantize_iq4_nl>;
-template [[host_name("kernel_mul_mm_iq4_xs_f32")]]  kernel mat_mm_t kernel_mul_mm<half,   half4x4,   simdgroup_half8x8,   block_iq4_xs,  QK_NL, dequantize_iq4_xs>;
+template [[host_name("kernel_mul_mm_q4_0_f32")]]    kernel mul_mm_t kernel_mul_mm<half,   half4x4,   simdgroup_half8x8,   block_q4_0,    2,     dequantize_q4_0>;
+template [[host_name("kernel_mul_mm_q4_1_f32")]]    kernel mul_mm_t kernel_mul_mm<half,   half4x4,   simdgroup_half8x8,   block_q4_1,    2,     dequantize_q4_1>;
+template [[host_name("kernel_mul_mm_q5_0_f32")]]    kernel mul_mm_t kernel_mul_mm<half,   half4x4,   simdgroup_half8x8,   block_q5_0,    2,     dequantize_q5_0>;
+template [[host_name("kernel_mul_mm_q5_1_f32")]]    kernel mul_mm_t kernel_mul_mm<half,   half4x4,   simdgroup_half8x8,   block_q5_1,    2,     dequantize_q5_1>;
+template [[host_name("kernel_mul_mm_q8_0_f32")]]    kernel mul_mm_t kernel_mul_mm<half,   half4x4,   simdgroup_half8x8,   block_q8_0,    2,     dequantize_q8_0>;
+template [[host_name("kernel_mul_mm_q2_K_f32")]]    kernel mul_mm_t kernel_mul_mm<half,   half4x4,   simdgroup_half8x8,   block_q2_K,    QK_NL, dequantize_q2_K>;
+template [[host_name("kernel_mul_mm_q3_K_f32")]]    kernel mul_mm_t kernel_mul_mm<half,   half4x4,   simdgroup_half8x8,   block_q3_K,    QK_NL, dequantize_q3_K>;
+template [[host_name("kernel_mul_mm_q4_K_f32")]]    kernel mul_mm_t kernel_mul_mm<half,   half4x4,   simdgroup_half8x8,   block_q4_K,    QK_NL, dequantize_q4_K>;
+template [[host_name("kernel_mul_mm_q5_K_f32")]]    kernel mul_mm_t kernel_mul_mm<half,   half4x4,   simdgroup_half8x8,   block_q5_K,    QK_NL, dequantize_q5_K>;
+template [[host_name("kernel_mul_mm_q6_K_f32")]]    kernel mul_mm_t kernel_mul_mm<half,   half4x4,   simdgroup_half8x8,   block_q6_K,    QK_NL, dequantize_q6_K>;
+template [[host_name("kernel_mul_mm_iq2_xxs_f32")]] kernel mul_mm_t kernel_mul_mm<half,   half4x4,   simdgroup_half8x8,   block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
+template [[host_name("kernel_mul_mm_iq2_xs_f32")]]  kernel mul_mm_t kernel_mul_mm<half,   half4x4,   simdgroup_half8x8,   block_iq2_xs,  QK_NL, dequantize_iq2_xs>;
+template [[host_name("kernel_mul_mm_iq3_xxs_f32")]] kernel mul_mm_t kernel_mul_mm<half,   half4x4,   simdgroup_half8x8,   block_iq3_xxs, QK_NL, dequantize_iq3_xxs>;
+template [[host_name("kernel_mul_mm_iq3_s_f32")]]   kernel mul_mm_t kernel_mul_mm<half,   half4x4,   simdgroup_half8x8,   block_iq3_s,   QK_NL, dequantize_iq3_s>;
+template [[host_name("kernel_mul_mm_iq2_s_f32")]]   kernel mul_mm_t kernel_mul_mm<half,   half4x4,   simdgroup_half8x8,   block_iq2_s,   QK_NL, dequantize_iq2_s>;
+template [[host_name("kernel_mul_mm_iq1_s_f32")]]   kernel mul_mm_t kernel_mul_mm<half,   half4x4,   simdgroup_half8x8,   block_iq1_s,   QK_NL, dequantize_iq1_s>;
+template [[host_name("kernel_mul_mm_iq1_m_f32")]]   kernel mul_mm_t kernel_mul_mm<half,   half4x4,   simdgroup_half8x8,   block_iq1_m,   QK_NL, dequantize_iq1_m>;
+template [[host_name("kernel_mul_mm_iq4_nl_f32")]]  kernel mul_mm_t kernel_mul_mm<half,   half4x4,   simdgroup_half8x8,   block_iq4_nl,  2,     dequantize_iq4_nl>;
+template [[host_name("kernel_mul_mm_iq4_xs_f32")]]  kernel mul_mm_t kernel_mul_mm<half,   half4x4,   simdgroup_half8x8,   block_iq4_xs,  QK_NL, dequantize_iq4_xs>;
 
 //
 // indirect matrix-matrix multiplication
 //
 
-typedef decltype(kernel_mul_mm_id<float4x4, 1, dequantize_f32>) mat_mm_id_t;
+typedef decltype(kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, float4x4, 1, dequantize_f32>) mul_mm_id;
 
-template [[host_name("kernel_mul_mm_id_f32_f32")]]     kernel mat_mm_id_t kernel_mul_mm_id<float4x4,      1,     dequantize_f32>;
-template [[host_name("kernel_mul_mm_id_f16_f32")]]     kernel mat_mm_id_t kernel_mul_mm_id<half4x4,       1,     dequantize_f16>;
+template [[host_name("kernel_mul_mm_id_f32_f16")]]     kernel mul_mm_id kernel_mul_mm_id<half,   half4x4,   simdgroup_half8x8,   float4x4,      1,     dequantize_f32>;
+template [[host_name("kernel_mul_mm_id_f16_f16")]]     kernel mul_mm_id kernel_mul_mm_id<half,   half4x4,   simdgroup_half8x8,   half4x4,       1,     dequantize_f16>;
 #if defined(GGML_METAL_USE_BF16)
-template [[host_name("kernel_mul_mm_id_bf16_f32")]]    kernel mat_mm_id_t kernel_mul_mm_id<bfloat4x4,     1,     dequantize_bf16>;
+template [[host_name("kernel_mul_mm_id_bf16_f16")]]    kernel mul_mm_id kernel_mul_mm_id<bfloat, bfloat4x4, simdgroup_bfloat8x8, bfloat4x4,     1,     dequantize_bf16>;
 #endif
-template [[host_name("kernel_mul_mm_id_q4_0_f32")]]    kernel mat_mm_id_t kernel_mul_mm_id<block_q4_0,    2,     dequantize_q4_0>;
-template [[host_name("kernel_mul_mm_id_q4_1_f32")]]    kernel mat_mm_id_t kernel_mul_mm_id<block_q4_1,    2,     dequantize_q4_1>;
-template [[host_name("kernel_mul_mm_id_q5_0_f32")]]    kernel mat_mm_id_t kernel_mul_mm_id<block_q5_0,    2,     dequantize_q5_0>;
-template [[host_name("kernel_mul_mm_id_q5_1_f32")]]    kernel mat_mm_id_t kernel_mul_mm_id<block_q5_1,    2,     dequantize_q5_1>;
-template [[host_name("kernel_mul_mm_id_q8_0_f32")]]    kernel mat_mm_id_t kernel_mul_mm_id<block_q8_0,    2,     dequantize_q8_0>;
-template [[host_name("kernel_mul_mm_id_q2_K_f32")]]    kernel mat_mm_id_t kernel_mul_mm_id<block_q2_K,    QK_NL, dequantize_q2_K>;
-template [[host_name("kernel_mul_mm_id_q3_K_f32")]]    kernel mat_mm_id_t kernel_mul_mm_id<block_q3_K,    QK_NL, dequantize_q3_K>;
-template [[host_name("kernel_mul_mm_id_q4_K_f32")]]    kernel mat_mm_id_t kernel_mul_mm_id<block_q4_K,    QK_NL, dequantize_q4_K>;
-template [[host_name("kernel_mul_mm_id_q5_K_f32")]]    kernel mat_mm_id_t kernel_mul_mm_id<block_q5_K,    QK_NL, dequantize_q5_K>;
-template [[host_name("kernel_mul_mm_id_q6_K_f32")]]    kernel mat_mm_id_t kernel_mul_mm_id<block_q6_K,    QK_NL, dequantize_q6_K>;
-template [[host_name("kernel_mul_mm_id_iq2_xxs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
-template [[host_name("kernel_mul_mm_id_iq2_xs_f32")]]  kernel mat_mm_id_t kernel_mul_mm_id<block_iq2_xs,  QK_NL, dequantize_iq2_xs>;
-template [[host_name("kernel_mul_mm_id_iq3_xxs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq3_xxs, QK_NL, dequantize_iq3_xxs>;
-template [[host_name("kernel_mul_mm_id_iq3_s_f32")]]   kernel mat_mm_id_t kernel_mul_mm_id<block_iq3_s,   QK_NL, dequantize_iq3_s>;
-template [[host_name("kernel_mul_mm_id_iq2_s_f32")]]   kernel mat_mm_id_t kernel_mul_mm_id<block_iq2_s,   QK_NL, dequantize_iq2_s>;
-template [[host_name("kernel_mul_mm_id_iq1_s_f32")]]   kernel mat_mm_id_t kernel_mul_mm_id<block_iq1_s,   QK_NL, dequantize_iq1_s>;
-template [[host_name("kernel_mul_mm_id_iq1_m_f32")]]   kernel mat_mm_id_t kernel_mul_mm_id<block_iq1_m,   QK_NL, dequantize_iq1_m>;
-template [[host_name("kernel_mul_mm_id_iq4_nl_f32")]]  kernel mat_mm_id_t kernel_mul_mm_id<block_iq4_nl,  2,     dequantize_iq4_nl>;
-template [[host_name("kernel_mul_mm_id_iq4_xs_f32")]]  kernel mat_mm_id_t kernel_mul_mm_id<block_iq4_xs,  QK_NL, dequantize_iq4_xs>;
+template [[host_name("kernel_mul_mm_id_q4_0_f16")]]    kernel mul_mm_id kernel_mul_mm_id<half,   half4x4,   simdgroup_half8x8,   block_q4_0,    2,     dequantize_q4_0>;
+template [[host_name("kernel_mul_mm_id_q4_1_f16")]]    kernel mul_mm_id kernel_mul_mm_id<half,   half4x4,   simdgroup_half8x8,   block_q4_1,    2,     dequantize_q4_1>;
+template [[host_name("kernel_mul_mm_id_q5_0_f16")]]    kernel mul_mm_id kernel_mul_mm_id<half,   half4x4,   simdgroup_half8x8,   block_q5_0,    2,     dequantize_q5_0>;
+template [[host_name("kernel_mul_mm_id_q5_1_f16")]]    kernel mul_mm_id kernel_mul_mm_id<half,   half4x4,   simdgroup_half8x8,   block_q5_1,    2,     dequantize_q5_1>;
+template [[host_name("kernel_mul_mm_id_q8_0_f16")]]    kernel mul_mm_id kernel_mul_mm_id<half,   half4x4,   simdgroup_half8x8,   block_q8_0,    2,     dequantize_q8_0>;
+template [[host_name("kernel_mul_mm_id_q2_K_f16")]]    kernel mul_mm_id kernel_mul_mm_id<half,   half4x4,   simdgroup_half8x8,   block_q2_K,    QK_NL, dequantize_q2_K>;
+template [[host_name("kernel_mul_mm_id_q3_K_f16")]]    kernel mul_mm_id kernel_mul_mm_id<half,   half4x4,   simdgroup_half8x8,   block_q3_K,    QK_NL, dequantize_q3_K>;
+template [[host_name("kernel_mul_mm_id_q4_K_f16")]]    kernel mul_mm_id kernel_mul_mm_id<half,   half4x4,   simdgroup_half8x8,   block_q4_K,    QK_NL, dequantize_q4_K>;
+template [[host_name("kernel_mul_mm_id_q5_K_f16")]]    kernel mul_mm_id kernel_mul_mm_id<half,   half4x4,   simdgroup_half8x8,   block_q5_K,    QK_NL, dequantize_q5_K>;
+template [[host_name("kernel_mul_mm_id_q6_K_f16")]]    kernel mul_mm_id kernel_mul_mm_id<half,   half4x4,   simdgroup_half8x8,   block_q6_K,    QK_NL, dequantize_q6_K>;
+template [[host_name("kernel_mul_mm_id_iq2_xxs_f16")]] kernel mul_mm_id kernel_mul_mm_id<half,   half4x4,   simdgroup_half8x8,   block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
+template [[host_name("kernel_mul_mm_id_iq2_xs_f16")]]  kernel mul_mm_id kernel_mul_mm_id<half,   half4x4,   simdgroup_half8x8,   block_iq2_xs,  QK_NL, dequantize_iq2_xs>;
+template [[host_name("kernel_mul_mm_id_iq3_xxs_f16")]] kernel mul_mm_id kernel_mul_mm_id<half,   half4x4,   simdgroup_half8x8,   block_iq3_xxs, QK_NL, dequantize_iq3_xxs>;
+template [[host_name("kernel_mul_mm_id_iq3_s_f16")]]   kernel mul_mm_id kernel_mul_mm_id<half,   half4x4,   simdgroup_half8x8,   block_iq3_s,   QK_NL, dequantize_iq3_s>;
+template [[host_name("kernel_mul_mm_id_iq2_s_f16")]]   kernel mul_mm_id kernel_mul_mm_id<half,   half4x4,   simdgroup_half8x8,   block_iq2_s,   QK_NL, dequantize_iq2_s>;
+template [[host_name("kernel_mul_mm_id_iq1_s_f16")]]   kernel mul_mm_id kernel_mul_mm_id<half,   half4x4,   simdgroup_half8x8,   block_iq1_s,   QK_NL, dequantize_iq1_s>;
+template [[host_name("kernel_mul_mm_id_iq1_m_f16")]]   kernel mul_mm_id kernel_mul_mm_id<half,   half4x4,   simdgroup_half8x8,   block_iq1_m,   QK_NL, dequantize_iq1_m>;
+template [[host_name("kernel_mul_mm_id_iq4_nl_f16")]]  kernel mul_mm_id kernel_mul_mm_id<half,   half4x4,   simdgroup_half8x8,   block_iq4_nl,  2,     dequantize_iq4_nl>;
+template [[host_name("kernel_mul_mm_id_iq4_xs_f16")]]  kernel mul_mm_id kernel_mul_mm_id<half,   half4x4,   simdgroup_half8x8,   block_iq4_xs,  QK_NL, dequantize_iq4_xs>;
+
 
 //
 // matrix-vector multiplication
diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c
index ee4fe9f723dc1..bc673292b37a3 100644
--- a/ggml/src/ggml.c
+++ b/ggml/src/ggml.c
@@ -2732,11 +2732,11 @@ void ggml_mul_mat_set_prec(
     c = ggml_mul_mat_id(ctx, as, b, ids);
 
     as  -> [cols, rows, n_expert]
-    ids -> [n_experts_used, n_tokens] (i32)
     b   -> [cols, n_expert_used, n_tokens]
+    ids -> [n_expert_used, n_tokens] (i32)
     c   -> [rows, n_expert_used, n_tokens]
 
-    in b, n_experts_used can be broadcasted to match the n_expert_used of ids
+    in b, n_expert_used can be broadcasted to match the n_expert_used of ids
 
     c ~= as[:,:,i] @ b[:,i%r,t], i = ids[e,t] for all e,t in ids
 */