Skip to content

Commit 92a3913

Browse files
authored
[CANN]MUL_MAT optimization (ggml-org#12382)
1 parent 9f2250b commit 92a3913

File tree

2 files changed

+6 -7 lines changed

ggml/src/ggml-cann/aclnn_ops.cpp

+6 -2
Original file line number | Diff line number | Diff line change
@@ -2790,10 +2790,14 @@ static void ggml_cann_mul_mat_quant(ggml_backend_cann_context& ctx,
27902790
(char*)output_buffer + batch1 * output_stride, ACL_FLOAT16,
27912791
output_elem_size, output_ne, output_nb, 2, ACL_FORMAT_ND,
27922792
output_ne_offset);
2793+
int64_t antiquantGroupSize = 0;
2794+
if (src0->ne[0] > QK8_0) {
2795+
antiquantGroupSize = QK8_0;
2796+
}
27932797

27942798
ACL_CHECK(aclnnWeightQuantBatchMatmulV2GetWorkspaceSize(
27952799
acl_input_tensor, acl_weight_tensor, acl_scale_tensor, nullptr,
2796-
nullptr, nullptr, nullptr, QK8_0, acl_output_tensor,
2800+
nullptr, nullptr, nullptr, antiquantGroupSize, acl_output_tensor,
27972801
&workspaceSize, &executor));
27982802
if (workspaceAddr == nullptr) {
27992803
workspaceAddr = workspace_allocator.alloc(workspaceSize);
@@ -2833,7 +2837,7 @@ static void ggml_cann_mul_mat_quant(ggml_backend_cann_context& ctx,
28332837

28342838
ACL_CHECK(aclnnWeightQuantBatchMatmulV2GetWorkspaceSize(
28352839
acl_input_tensor, acl_weight_tensor, acl_scale_tensor,
2836-
nullptr, nullptr, nullptr, nullptr, QK8_0,
2840+
nullptr, nullptr, nullptr, nullptr, antiquantGroupSize,
28372841
acl_output_tensor, &workspaceSize, &executor));
28382842
ACL_CHECK(aclnnWeightQuantBatchMatmulV2(
28392843
workspaceAddr, workspaceSize, executor, ctx.stream()));

ggml/src/ggml-cann/ggml-cann.cpp

-5
Original file line number | Diff line number | Diff line change
@@ -1689,11 +1689,6 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
16891689
case GGML_OP_MUL_MAT: {
16901690
switch (op->src[0]->type) {
16911691
case GGML_TYPE_Q8_0:
1692-
// Current groupsize should not be greater than k-1 in
1693-
// aclnnWeightQuantBatchMatmulV2GetWorkspaceSize
1694-
if (op->src[0]->ne[0] <= QK8_0) {
1695-
return false;
1696-
}
16971692
case GGML_TYPE_F16:
16981693
case GGML_TYPE_F32:
16991694
case GGML_TYPE_Q4_0:

0 commit comments

Comments
 (0)