diff --git a/backends/vulkan/runtime/graph/ops/impl/Convolution.cpp b/backends/vulkan/runtime/graph/ops/impl/Convolution.cpp index 5250c3baef2..ba1f50a23c1 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Convolution.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Convolution.cpp @@ -404,6 +404,21 @@ void add_conv2d_node( wg_size = {wg_size[0] * wg_size[1], wg_size[2], 1}; } + utils::uvec3 local_wg_size; + if (method == Conv2dMethod::Pointwise) { + uint32_t local_wg_size_y = 1; + if (wg_size[1] % 8 == 0) { + local_wg_size_y = 8; + } else if (wg_size[1] % 4 == 0) { + local_wg_size_y = 4; + } else if (wg_size[1] % 2 == 0) { + local_wg_size_y = 2; + } + local_wg_size = {64 / local_wg_size_y, local_wg_size_y, 1}; + } else { + local_wg_size = graph.create_local_wg_size(wg_size); + } + vkapi::ParamsBindList param_buffers; std::vector push_constants; if (method == Conv2dMethod::Pointwise) { @@ -464,7 +479,7 @@ void add_conv2d_node( graph, shader, wg_size, - graph.create_local_wg_size(wg_size), + local_wg_size, // Inputs and Outputs {{out, vkapi::kWrite}, {{in, arg_weight, arg_bias}, vkapi::kRead}}, // Shader params buffers