@@ -2790,10 +2790,14 @@ static void ggml_cann_mul_mat_quant(ggml_backend_cann_context& ctx,
2790
2790
(char *)output_buffer + batch1 * output_stride, ACL_FLOAT16,
2791
2791
output_elem_size, output_ne, output_nb, 2 , ACL_FORMAT_ND,
2792
2792
output_ne_offset);
2793
+ int64_t antiquantGroupSize = 0 ;
2794
+ if (src0->ne [0 ] > QK8_0) {
2795
+ antiquantGroupSize = QK8_0;
2796
+ }
2793
2797
2794
2798
ACL_CHECK (aclnnWeightQuantBatchMatmulV2GetWorkspaceSize (
2795
2799
acl_input_tensor, acl_weight_tensor, acl_scale_tensor, nullptr ,
2796
- nullptr , nullptr , nullptr , QK8_0 , acl_output_tensor,
2800
+ nullptr , nullptr , nullptr , antiquantGroupSize , acl_output_tensor,
2797
2801
&workspaceSize, &executor));
2798
2802
if (workspaceAddr == nullptr ) {
2799
2803
workspaceAddr = workspace_allocator.alloc (workspaceSize);
@@ -2833,7 +2837,7 @@ static void ggml_cann_mul_mat_quant(ggml_backend_cann_context& ctx,
2833
2837
2834
2838
ACL_CHECK (aclnnWeightQuantBatchMatmulV2GetWorkspaceSize (
2835
2839
acl_input_tensor, acl_weight_tensor, acl_scale_tensor,
2836
- nullptr , nullptr , nullptr , nullptr , QK8_0 ,
2840
+ nullptr , nullptr , nullptr , nullptr , antiquantGroupSize ,
2837
2841
acl_output_tensor, &workspaceSize, &executor));
2838
2842
ACL_CHECK (aclnnWeightQuantBatchMatmulV2 (
2839
2843
workspaceAddr, workspaceSize, executor, ctx.stream ()));
0 commit comments