Skip to content

Commit bb6d00b

Browse files
authored
metal : move mm_id indices to shared mem (ggml-org#5982)
1 parent 7ab7b73 commit bb6d00b

File tree

2 files changed

+6
-6
lines changed

2 files changed

+6
-6
lines changed

ggml-metal.m

+3 -3
Original file line number | Diff line number | Diff line change
@@ -1642,8 +1642,8 @@ static enum ggml_status ggml_metal_graph_compute(
16421642
// TODO: make this more general
16431643
GGML_ASSERT(n_as <= 8);
16441644

1645-
// max size of the src1ids array in the kernel stack
1646-
GGML_ASSERT(ne11 <= 512);
1645+
// max size of the src1ids array in the kernel shared buffer
1646+
GGML_ASSERT(ne11 <= 4096);
16471647

16481648
const int64_t ne20 = src2 ? src2->ne[0] : 0;
16491649
const int64_t ne21 = src2 ? src2->ne[1] : 0;
@@ -1741,7 +1741,7 @@ static enum ggml_status ggml_metal_graph_compute(
17411741
[encoder setBuffer:id_src_cur offset:offs_src_cur atIndex:19 + j];
17421742
}
17431743

1744-
[encoder setThreadgroupMemoryLength:8192 atIndex:0];
1744+
[encoder setThreadgroupMemoryLength:GGML_PAD(8192 + 2*ne11, 16) atIndex:0];
17451745

17461746
[encoder dispatchThreadgroups:MTLSizeMake((ne11 + 31)/32, (ne21 + 63)/64, n_as*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
17471747
} else {

ggml-metal.metal

+3 -3
Original file line number | Diff line number | Diff line change
@@ -5386,7 +5386,7 @@ template<typename block_q, short nl, void (*dequantize_func)(device const block_
53865386
void kernel_mul_mm_id_impl(
53875387
device const uchar * src0,
53885388
device const uchar * src1,
5389-
thread short * src1ids,
5389+
threadgroup short * src1ids,
53905390
device float * dst,
53915391
constant int64_t & ne00,
53925392
constant int64_t & ne02,
@@ -5589,9 +5589,9 @@ kernel void kernel_mul_mm_id(
55895589
tgpig.z = tgpig.z%(ne12*ne13);
55905590

55915591
// row indices of src1 for expert id
5592-
int64_t _ne1 = 0;
5593-
short src1ids[512];
5592+
threadgroup short * src1ids = (threadgroup short *)(shared_memory + 8192);
55945593

5594+
int64_t _ne1 = 0;
55955595
for (int64_t i1 = 0; i1 < ne1; i1++) {
55965596
if (((device int32_t *) (ids + i1*nbi1))[idx] == id) {
55975597
src1ids[_ne1++] = i1;

0 commit comments

Comments
 (0)