Skip to content

Commit 5eaf262

Browse files
committed
q80*q40: add an ARM dotprod matmul path for a Q80 A matrix with a Q40 B matrix.
1 parent 0493606 commit 5eaf262

File tree

1 file changed

+16
-5
lines changed

1 file changed

+16
-5
lines changed

src/llamafile-sgemm.cpp

+16 -5
Original file line number | Diff line number | Diff line change
@@ -458,12 +458,12 @@ class tinyBLAS {
458458
// QUANT ZERO MATRIX MULTIPLICATION
459459

460460
#if defined(__ARM_FEATURE_DOTPROD)
461-
template <typename TA>
461+
template <typename TA, typename TB>
462462
class tinyBLAS_Q0_ARM {
463463
public:
464464
tinyBLAS_Q0_ARM(int64_t k,
465465
const TA *A, int64_t lda,
466-
const BlockQ80 *B, int64_t ldb,
466+
const TB *B, int64_t ldb,
467467
float *C, int64_t ldc,
468468
int ith, int nth)
469469
: A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) {
@@ -584,7 +584,7 @@ class tinyBLAS_Q0_ARM {
584584
}
585585

586586
const TA *const A;
587-
const BlockQ80 *const B;
587+
const TB *const B;
588588
float *const C;
589589
const int64_t k;
590590
const int64_t lda;
@@ -936,6 +936,17 @@ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda
936936
}
937937

938938
case Q80: {
939+
#if defined(__ARM_FEATURE_DOTPROD)
940+
if (Btype == Q40) {
941+
tinyBLAS_Q0_ARM<BlockQ80, BlockQ40> tb{
942+
k, (const BlockQ80 *)A, lda,
943+
(const BlockQ40 *)B, ldb,
944+
(float *)C, ldc,
945+
ith, nth};
946+
tb.matmul(m, n, task);
947+
return true;
948+
}
949+
#endif
939950
if (Btype != Q80)
940951
return false;
941952
#if defined(__AVX2__) || defined(__AVX512F__)
@@ -947,7 +958,7 @@ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda
947958
tb.matmul(m, n, task);
948959
return true;
949960
#elif defined(__ARM_FEATURE_DOTPROD)
950-
tinyBLAS_Q0_ARM<BlockQ80> tb{
961+
tinyBLAS_Q0_ARM<BlockQ80, BlockQ80> tb{
951962
k, (const BlockQ80 *)A, lda,
952963
(const BlockQ80 *)B, ldb,
953964
(float *)C, ldc,
@@ -971,7 +982,7 @@ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda
971982
tb.matmul(m, n, task);
972983
return true;
973984
#elif defined(__ARM_FEATURE_DOTPROD)
974-
tinyBLAS_Q0_ARM<BlockQ40> tb{
985+
tinyBLAS_Q0_ARM<BlockQ40, BlockQ80> tb{
975986
k, (const BlockQ40 *)A, lda,
976987
(const BlockQ80 *)B, ldb,
977988
(float *)C, ldc,

0 commit comments

Comments
 (0)