Skip to content

Commit 30ae0d2

Browse files
committed
tinyblas dynamic dispatching
1 parent f48c35d commit 30ae0d2

File tree

4 files changed

+140
-57
lines changed

4 files changed

+140
-57
lines changed

examples/server/tests/unit/test_completion.py

+2-2
Original file line number | Diff line number | Diff line change
@@ -78,7 +78,7 @@ def test_completion_stream_vs_non_stream():
7878
@pytest.mark.parametrize("n_slots", [1, 2])
7979
def test_consistent_result_same_seed(n_slots: int):
8080
global server
81-
server.n_slots = 1
81+
server.n_slots = n_slots
8282
server.start()
8383
last_res = None
8484
for _ in range(4):
@@ -115,7 +115,7 @@ def test_different_result_different_seed(n_slots: int):
115115
@pytest.mark.parametrize("temperature", [0.0, 1.0])
116116
def test_consistent_result_different_batch_size(n_batch: int, temperature: float):
117117
global server
118-
server.n_batch = 1
118+
server.n_batch = n_batch
119119
server.start()
120120
last_res = None
121121
for _ in range(4):

ggml/src/ggml-cpu/ggml-cpu.c

+4-4
Original file line number | Diff line number | Diff line change
@@ -7420,14 +7420,14 @@ static void ggml_compute_forward_mul_mat(
74207420
if (src1_cont) {
74217421
for (int64_t i13 = 0; i13 < ne13; i13++)
74227422
for (int64_t i12 = 0; i12 < ne12; i12++)
7423-
if (!llamafile_sgemm(ne01, ne11, ne00/ggml_blck_size(src0->type),
7423+
if (!llamafile_sgemm(params,
7424+
ne01, ne11, ne00/ggml_blck_size(src0->type),
74247425
(const char *)src0->data + i12/r2*nb02 + i13/r3*nb03,
74257426
nb01/ggml_type_size(src0->type),
74267427
(const char *)src1->data + i12*nb12 + i13*nb13,
74277428
nb11/ggml_type_size(src1->type),
74287429
(char *)dst->data + i12*nb2 + i13*nb3,
74297430
nb1/ggml_type_size(dst->type),
7430-
ith, nth,
74317431
src0->type,
74327432
src1->type,
74337433
dst->type))
@@ -7472,14 +7472,14 @@ UseGgmlGemm1:;
74727472

74737473
for (int64_t i13 = 0; i13 < ne13; i13++)
74747474
for (int64_t i12 = 0; i12 < ne12; i12++)
7475-
if (!llamafile_sgemm(ne01, ne11, ne00/ggml_blck_size(src0->type),
7475+
if (!llamafile_sgemm(params,
7476+
ne01, ne11, ne00/ggml_blck_size(src0->type),
74767477
(const char *)src0->data + i12/r2*nb02 + i13/r3*nb03,
74777478
nb01/ggml_type_size(src0->type),
74787479
(const char *)wdata + (i12*ne11 + i13*ne12*ne11)*row_size,
74797480
row_size/ggml_type_size(vec_dot_type),
74807481
(char *)dst->data + i12*nb2 + i13*nb3,
74817482
nb1/ggml_type_size(dst->type),
7482-
ith, nth,
74837483
src0->type,
74847484
vec_dot_type,
74857485
dst->type))

0 commit comments

Comments
 (0)