From 4c07ef278632ea3831331c4506aa49c47fe8dea6 Mon Sep 17 00:00:00 2001 From: Vivek Trivedi <5340687+trivedivivek@users.noreply.github.com> Date: Fri, 30 May 2025 10:16:04 -0700 Subject: [PATCH 1/2] [ET-VK] Modifying should_squeeze function in SqueezeUnsqueezeInputs to not squeeze if significant axis are all 1 and trailing axis are all > 1. Pull Request resolved: https://github.com/pytorch/executorch/pull/11177 This diff modifies the `should_squeeze` function in `SqueezeUnsqueezeInputs` to not squeeze (return False) if significant axes are all 1 and trailing axes are all > 1. ghstack-source-id: 287222796 @exported-using-ghexport Differential Revision: [D75483587](https://our.internmc.facebook.com/intern/diff/D75483587/) --- backends/vulkan/_passes/squeeze_unsqueeze_inputs.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/backends/vulkan/_passes/squeeze_unsqueeze_inputs.py b/backends/vulkan/_passes/squeeze_unsqueeze_inputs.py index b4337829d7f..c415249383e 100644 --- a/backends/vulkan/_passes/squeeze_unsqueeze_inputs.py +++ b/backends/vulkan/_passes/squeeze_unsqueeze_inputs.py @@ -32,7 +32,13 @@ def should_squeeze(self, op, shape: List[int]) -> bool: # pyre-ignore return shape[1] == 1 and shape[0] > 1 if len(shape) == 4: # No need to squeeze if all dims are 1 except the width dim - if all(dim == 1 for dim in shape[:-1]): + if shape[0] == shape[1] == shape[2] == 1: + return False + # No need to squeeze if batch and channel dims are 1 and height and width are > 1 + if shape[0] == shape[1] == 1 and shape[2] > 1 and shape[3] > 1: + return False + # No need to squeeze if batch dim is 1 and channel, height and width are > 1 + if shape[0] == 1 and shape[1] > 1 and shape[2] > 1 and shape[3] > 1: return False # Otherwise, check for squeezable dim return 1 in shape[:-1] From 7099bcdad6e36dc4381187a84b4f291ec765c804 Mon Sep 17 00:00:00 2001 From: Vivek Trivedi <5340687+trivedivivek@users.noreply.github.com> Date: Fri, 30 May 2025 10:16:06 -0700 Subject: [PATCH 2/2] [ET-VK] Removed shared memory usage and simplied conv2d dw op shader to improve performance. Pull Request resolved: https://github.com/pytorch/executorch/pull/11178 This diff removes shared memory usage in `conv2d_dw_output_tile.glsl` shader to improve performance. Makes sum a one dimensional array, and moves bias application before storing texel. ghstack-source-id: 287222799 @exported-using-ghexport Differential Revision: [D75499165](https://our.internmc.facebook.com/intern/diff/D75499165/) --- .../graph/ops/glsl/conv2d_dw_output_tile.glsl | 28 ++++++------------- 1 file changed, 9 insertions(+), 19 deletions(-) diff --git a/backends/vulkan/runtime/graph/ops/glsl/conv2d_dw_output_tile.glsl b/backends/vulkan/runtime/graph/ops/glsl/conv2d_dw_output_tile.glsl index 3265a973980..0ee19206f59 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/conv2d_dw_output_tile.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/conv2d_dw_output_tile.glsl @@ -47,11 +47,6 @@ layout(push_constant) uniform restrict Block { layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; -// For performance improvement, reduce register usage by caching positions in shared memory. -// Offset index by 1 every 16 points to avoid bank access conflict. -#define offset_pos_index(index) (index + ((index) >> 4)) -shared ivec3 pos_shared[offset_pos_index(LOCAL_WG_SIZE)]; - /* * Computes a depthwise convolution. Each shader invocation calculates the * output at a single output location. @@ -77,8 +72,6 @@ void main() { return; } - pos_shared[offset_pos_index(gl_LocalInvocationIndex)] = pos; - // Compute the index of the top-left element of the overlay region. Negative // indices indicate that the top-left element is in a region added by padding. const ivec2 ipos = pos.xy * stride - padding; @@ -89,13 +82,10 @@ void main() { const ivec2 end = ipos + overlay_region.xy; // sum outputs - VEC4_T sum[BATCH_SIZE_Y][BATCH_SIZE_X]; + VEC4_T sum[BATCH_SIZE_Y * BATCH_SIZE_X]; - sum[0][0] = texelFetch(t_bias, ivec2(pos.z, 0), 0); - for (int y = 0; y < BATCH_SIZE_Y; y++) { - for (int x = 0; x < BATCH_SIZE_X; x++) { - sum[y][x] = sum[0][0]; - } + for (int i = 0; i < BATCH_SIZE_Y * BATCH_SIZE_X; i++) { + sum[i] = VEC4_T(0); } // array to store input texels @@ -115,7 +105,7 @@ void main() { if (i > 0) { for (int j = 0; j < TILE_SIZE; j++) { for (int s = 0; s < BATCH_SIZE_X; s++) { - sum[1][s] = fma(in_texels[j + s], prev_kernel_line[j], sum[1][s]); + sum[BATCH_SIZE_X + s] = fma(in_texels[j + s], prev_kernel_line[j], sum[BATCH_SIZE_X + s]); } } } @@ -125,19 +115,19 @@ void main() { for (int j = 0; j < TILE_SIZE; j++, kx++) { prev_kernel_line[j] = texelFetch(t_kernel, ivec2(kx, pos.z), 0); for (int s = 0; s < BATCH_SIZE_X; s++) { - sum[0][s] = fma(in_texels[j + s], prev_kernel_line[j], sum[0][s]); + sum[s] = fma(in_texels[j + s], prev_kernel_line[j], sum[s]); } } } } - const ivec3 out_pos = pos_shared[offset_pos_index(gl_LocalInvocationIndex)]; + const VEC4_T bias = texelFetch(t_bias, ivec2(pos.z, 0), 0); for (int y = 0; y < BATCH_SIZE_Y; y++) { for (int x = 0; x < BATCH_SIZE_X; x++) { - if (any(greaterThanEqual(ivec3(out_pos.x + x, out_pos.y + y, out_pos.z), out_limits.xyz))) { - continue; + const ivec3 out_pos = ivec3(pos.x + x, pos.y + y, pos.z); + if (all(lessThan(out_pos.xy, out_limits.xy))) { + imageStore(t_out, out_pos, op(sum[y * BATCH_SIZE_X + x] + bias, out_min, out_max)); } - imageStore(t_out, ivec3(out_pos.x + x, out_pos.y + y, out_pos.z), op(sum[y][x], out_min, out_max)); } } }