Skip to content

Commit 192090b

Browse files
authored
llamafile : improve sgemm.cpp (ggml-org#6796)
* llamafile : improve sgemm.cpp - Re-enable by default - Fix issue described in ggml-org#6716 - Make code more abstract, elegant, and maintainable - Faster handling of weirdly shaped `m` and `n` edge cases * Address review comments * Help clang produce fma instructions * Address review comments
1 parent e931888 commit 192090b

File tree

4 files changed

+412
-573
lines changed

4 files changed

+412
-573
lines changed

CMakeLists.txt

Lines changed: 5 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -43,17 +43,11 @@ else()
4343
set(LLAMA_METAL_DEFAULT OFF)
4444
endif()
4545

46-
# TODO: fix this for Android CI
47-
# https://github.com/ggerganov/llama.cpp/pull/6716#issuecomment-2061509191
48-
#if (CMAKE_SYSTEM_NAME MATCHES "ANDROID")
49-
# set(LLAMA_LLAMAFILE_DEFAULT OFF)
50-
#else()
51-
# set(LLAMA_LLAMAFILE_DEFAULT ON)
52-
#endif()
53-
54-
# TODO: temporary disable until MoE is fixed
55-
# https://github.com/ggerganov/llama.cpp/pull/6716
56-
set(LLAMA_LLAMAFILE_DEFAULT OFF)
46+
if (CMAKE_SYSTEM_NAME MATCHES "ANDROID")
47+
set(LLAMA_LLAMAFILE_DEFAULT OFF)
48+
else()
49+
set(LLAMA_LLAMAFILE_DEFAULT ON)
50+
endif()
5751

5852
# general
5953
option(BUILD_SHARED_LIBS "build shared libraries" OFF)

Makefile

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -384,10 +384,6 @@ ifdef LLAMA_OPENBLAS
384384
MK_LDFLAGS += $(shell pkg-config --libs openblas)
385385
endif # LLAMA_OPENBLAS
386386

387-
# TODO: temporary disable until MoE is fixed
388-
# https://github.com/ggerganov/llama.cpp/pull/6716
389-
LLAMA_NO_LLAMAFILE := 1
390-
391387
ifndef LLAMA_NO_LLAMAFILE
392388
MK_CPPFLAGS += -DGGML_USE_LLAMAFILE
393389
OBJS += sgemm.o

ggml.c

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10825,7 +10825,7 @@ static void ggml_compute_forward_mul_mat(
1082510825
#endif
1082610826

1082710827
#if GGML_USE_LLAMAFILE
10828-
if (nb10 == ggml_type_size(src1->type)) {
10828+
if (src1_cont) {
1082910829
for (int64_t i13 = 0; i13 < ne13; i13++)
1083010830
for (int64_t i12 = 0; i12 < ne12; i12++)
1083110831
if (!llamafile_sgemm(ne01, ne11, ne00/ggml_blck_size(src0->type),
@@ -10878,15 +10878,13 @@ UseGgmlGemm1:;
1087810878
const size_t row_size = ggml_row_size(vec_dot_type, ne10);
1087910879

1088010880
#if GGML_USE_LLAMAFILE
10881-
if (nb10 == ggml_type_size(src1->type) || src1->type != vec_dot_type) {
10881+
if (src1->type != vec_dot_type) {
1088210882
for (int64_t i13 = 0; i13 < ne13; i13++)
1088310883
for (int64_t i12 = 0; i12 < ne12; i12++)
1088410884
if (!llamafile_sgemm(ne01, ne11, ne00/ggml_blck_size(src0->type),
1088510885
(const char *)src0->data + i12/r2*nb02 + i13/r3*nb03,
1088610886
nb01/ggml_type_size(src0->type),
10887-
(const char *)wdata + ggml_row_size(vec_dot_type,
10888-
nb12/ggml_type_size(src1->type)*i12 +
10889-
nb13/ggml_type_size(src1->type)*i13),
10887+
(const char *)wdata + (i12*ne11 + i13*ne12*ne11)*row_size,
1089010888
row_size/ggml_type_size(vec_dot_type),
1089110889
(char *)dst->data + i12*nb2 + i13*nb3,
1089210890
nb1/ggml_type_size(dst->type),

0 commit comments

Comments
 (0)