Skip to content

[ET-VK] Creating specialized version of conv2d pw shader for X and Y stride = 1 and padding = 0. #11190

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits into from
May 29, 2025
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
96 changes: 55 additions & 41 deletions backends/vulkan/runtime/graph/ops/glsl/conv2d_pw.glsl
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@

#define TILE_SIZE_X ${TILE_SIZE_X}
#define TILE_SIZE_Y ${TILE_SIZE_Y}
#define LOCAL_WG_SIZE 64

#define op(X, A, B) ${OPERATOR}

Expand All @@ -39,59 +38,61 @@ layout(push_constant) uniform restrict Block {

layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

// For performance improvement, reduce register usage by caching positions in shared memory.
// Offset index by 1 every 16 points to avoid bank access conflict.
#define offset_pos_index(index) (index + ((index) >> 4))
shared ivec3 pos_shared[offset_pos_index(LOCAL_WG_SIZE * TILE_SIZE_X * TILE_SIZE_Y)];
#extension GL_EXT_control_flow_attributes : require

/*
* Computes a 2D pointwise convolution of an NxN output tile. Calculating an
* output tile for pointwise convolution is more efficient because the kernel
* size is only 1x1, making it easier to re-use loaded texels from t_kernel.
*/
void main() {
const ivec2 out_limits_scaled = (out_limits.xy + ivec2(TILE_SIZE_X - 1, TILE_SIZE_Y - 1)) / ivec2(TILE_SIZE_X, TILE_SIZE_Y);
const uint shared_mem_stride = LOCAL_WG_SIZE;
const int out_limits_scaled[2] = {out_limits.x + (TILE_SIZE_X - 1) * TILE_SIZE_X, out_limits.y + (TILE_SIZE_Y - 1) * TILE_SIZE_Y};

const uint div_by_x = gl_GlobalInvocationID.x / out_limits_scaled.x;
const ivec3 gpos = ivec3(
gl_GlobalInvocationID.x % out_limits_scaled.x,
div_by_x % out_limits_scaled.y,
div_by_x / out_limits_scaled.y);
const int div_by_x = int(gl_GlobalInvocationID.x / out_limits_scaled[0]);
const int out_pos[3] = {int(gl_GlobalInvocationID.x % out_limits_scaled[0]), div_by_x, int(gl_GlobalInvocationID.y)};

// If the top left position is out of bounds, then this invocation will have
// no work to do.
if (out_pos[1] >= out_limits_scaled[1] || out_pos[2] >= out_limits.z) {
return;
}

// Output position for TILE_SIZE = 2
// +--------+--------+
// | pos[0] | pos[1] |
// +--------+--------+
// | pos[2] | pos[3] |
// +--------+--------+
ivec2 pos[TILE_SIZE_X * TILE_SIZE_Y];
int pos[TILE_SIZE_X * TILE_SIZE_Y * 2];
for (int y = 0, i = 0; y < TILE_SIZE_Y; ++y) {
for (int x = 0; x < TILE_SIZE_X; ++x) {
pos[i] = ivec2(gpos.x * TILE_SIZE_X + x, gpos.y * TILE_SIZE_Y + y);
pos_shared[offset_pos_index((shared_mem_stride * i) + gl_LocalInvocationIndex)] = ivec3(pos[i], gpos.z);
pos[i * 2] = out_pos[0] * TILE_SIZE_X + x;
pos[i * 2 + 1] = out_pos[1] * TILE_SIZE_Y + y;
i++;
}
}

// If the top left position is out of bounds, then this invocation will have
// no work to do.
if (gpos.z >= out_limits.z) {
return;
}

// Compute the index of the input texture that needs to be loaded for each
// output position. Note that negative indices can be produced indicating that
// the top-left element is in a region added by padding.
ivec2 ipos[TILE_SIZE_X * TILE_SIZE_Y];
int ipos[TILE_SIZE_X * TILE_SIZE_Y * 2];
for (int i = 0; i < TILE_SIZE_X * TILE_SIZE_Y; ++i) {
ipos[i] = pos[i] * stride - padding;
ipos[i * 2] = pos[i * 2] * stride.x - padding.x;
ipos[i * 2 + 1] = pos[i * 2 + 1] * stride.y - padding.y;
}

vec4 sum[TILE_SIZE_X * TILE_SIZE_Y];
sum[0] = texelFetch(t_bias, ivec2(gpos.z, 0), 0);
for (int i = 1; i < TILE_SIZE_X * TILE_SIZE_Y; ++i) {
sum[i] = sum[0];
// Final output array where each element is a tensor value.
// Tuple of consecutive 4 elements represents a single output texel.
float sum[TILE_SIZE_X * TILE_SIZE_Y * 4];

const vec4 bias = texelFetch(t_bias, ivec2(out_pos[2], 0), 0);

// Initialize the output array with the bias value
for (int i = 0; i < TILE_SIZE_X * TILE_SIZE_Y * 4; i += 4) {
sum[i] = bias.x;
sum[i + 1] = bias.y;
sum[i + 2] = bias.z;
sum[i + 3] = bias.w;
}

int z4 = 0;
Expand All @@ -100,14 +101,26 @@ void main() {
// During prepacking, the weight tensor has been permuted so that the
// channel (IC) dim is along the x-axis, and the batch (OC) dim is along
// the z-axis.
const vec4 ktex_0 = texelFetchOffset(t_kernel, ivec2(z, gpos.z), 0, ivec2(0, 0));
const vec4 ktex_1 = texelFetchOffset(t_kernel, ivec2(z, gpos.z), 0, ivec2(1, 0));
const vec4 ktex_2 = texelFetchOffset(t_kernel, ivec2(z, gpos.z), 0, ivec2(2, 0));
const vec4 ktex_3 = texelFetchOffset(t_kernel, ivec2(z, gpos.z), 0, ivec2(3, 0));
float kernel_values[4 * 4]; // 4 channels, 4 elements per channel

// Load kernel values from texels to array
[[unroll]] for (int i = 0; i < 4; ++i) {
const vec4 k_tex = texelFetch(t_kernel, ivec2(z + i, out_pos[2]), 0);
kernel_values[i * 4 + 0] = k_tex.x;
kernel_values[i * 4 + 1] = k_tex.y;
kernel_values[i * 4 + 2] = k_tex.z;
kernel_values[i * 4 + 3] = k_tex.w;
}

#pragma unroll
for (int i = 0; i < TILE_SIZE_X * TILE_SIZE_Y; ++i) {
const vec4 in_tex = texelFetch(t_in, ivec3(ipos[i], z4), 0);
const vec4 in_tex = texelFetch(t_in, ivec3(ipos[i * 2], ipos[i * 2 + 1], z4), 0);
// Load the input texel into an array
float tex_values[4];
tex_values[0] = in_tex.x;
tex_values[1] = in_tex.y;
tex_values[2] = in_tex.z;
tex_values[3] = in_tex.w;

// For 2x2 tile size algorithm works as follows.
// To explain the calculations below, the contents of one in_tex and the
// group of 4 texels loaded from t_kernel are shown:
Expand Down Expand Up @@ -141,18 +154,19 @@ void main() {
//
// which is what is expressed in the following calculations. This is done
// for each output position.
sum[i] = fma(in_tex.xxxx, ktex_0, sum[i]);
sum[i] = fma(in_tex.yyyy, ktex_1, sum[i]);
sum[i] = fma(in_tex.zzzz, ktex_2, sum[i]);
sum[i] = fma(in_tex.wwww, ktex_3, sum[i]);
for (int j = 0; j < 4; ++j) {
sum[i * 4 + j] = tex_values[0] * kernel_values[0 + j] + sum[i * 4 + j];
sum[i * 4 + j] = tex_values[1] * kernel_values[4 + j] + sum[i * 4 + j];
sum[i * 4 + j] = tex_values[2] * kernel_values[8 + j] + sum[i * 4 + j];
sum[i * 4 + j] = tex_values[3] * kernel_values[12 + j] + sum[i * 4 + j];
}
}
}

for (int i = 0; i < TILE_SIZE_X * TILE_SIZE_Y; ++i) {
const uint index = (shared_mem_stride * i) + gl_LocalInvocationIndex;
const ivec3 pos = pos_shared[offset_pos_index(index)];
if (all(lessThan(pos, out_limits.xyz))) {
imageStore(t_out, pos, op(sum[i], out_min, out_max));
const ivec3 pos_l = ivec3(pos[i * 2], pos[i * 2 + 1], out_pos[2]);
if (all(lessThan(pos_l, out_limits.xyz))) {
imageStore(t_out, pos_l, op(vec4(sum[i * 4], sum[i * 4 + 1], sum[i * 4 + 2], sum[i * 4 + 3]), out_min, out_max));
}
}
}
4 changes: 2 additions & 2 deletions backends/vulkan/runtime/graph/ops/glsl/conv2d_pw.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@ conv2d_pw:
OPERATOR: X
NDIM: 3
DTYPE: float
TILE_SIZE_X: 2
TILE_SIZE_Y: 2
TILE_SIZE_X: 1
TILE_SIZE_Y: 4
generate_variant_forall:
DTYPE:
- VALUE: half
Expand Down
163 changes: 163 additions & 0 deletions backends/vulkan/runtime/graph/ops/glsl/conv2d_pw_s1p0.glsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,163 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#version 450 core

#define PRECISION ${PRECISION}

#define VEC4_T ${texel_type(DTYPE)}

#define TILE_SIZE_X ${TILE_SIZE_X}
#define TILE_SIZE_Y ${TILE_SIZE_Y}

#define op(X, A, B) ${OPERATOR}

#include "indexing_utils.h"

layout(std430) buffer;

${layout_declare_tensor(0, "w", "t_out", DTYPE, "texture3d")}
${layout_declare_tensor(1, "r", "t_in", DTYPE, "texture3d")}
${layout_declare_tensor(2, "r", "t_kernel", DTYPE, "texture2d")}
${layout_declare_tensor(3, "r", "t_bias", DTYPE, "texture2d")}

layout(push_constant) uniform restrict Block {
ivec4 out_limits;
ivec2 stride;
ivec2 padding;
int in_group_size;
int dummy_padding;
float out_min;
float out_max;
};

layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

#extension GL_EXT_control_flow_attributes : require

/*
* Computes a 2D pointwise convolution of an NxN output tile. Calculating an
* output tile for pointwise convolution is more efficient because the kernel
* size is only 1x1, making it easier to re-use loaded texels from t_kernel.
*/
void main() {
const int out_limits_scaled[2] = {out_limits.x + (TILE_SIZE_X - 1) * TILE_SIZE_X, out_limits.y + (TILE_SIZE_Y - 1) * TILE_SIZE_Y};

const int div_by_x = int(gl_GlobalInvocationID.x / out_limits_scaled[0]);
const int out_pos[3] = {int(gl_GlobalInvocationID.x % out_limits_scaled[0]), div_by_x, int(gl_GlobalInvocationID.y)};

// If the top left position is out of bounds, then this invocation will have
// no work to do.
if (out_pos[1] >= out_limits_scaled[1] || out_pos[2] >= out_limits.z) {
return;
}

// Output position for TILE_SIZE = 2
// +--------+--------+
// | pos[0] | pos[1] |
// +--------+--------+
// | pos[2] | pos[3] |
// +--------+--------+
int pos[TILE_SIZE_X * TILE_SIZE_Y * 2];
for (int y = 0, i = 0; y < TILE_SIZE_Y; ++y) {
for (int x = 0; x < TILE_SIZE_X; ++x) {
pos[i * 2] = out_pos[0] * TILE_SIZE_X + x;
pos[i * 2 + 1] = out_pos[1] * TILE_SIZE_Y + y;
i++;
}
}

// Final output array where each element is a tensor value.
// Tuple of consecutive 4 elements represents a single output texel.
float sum[TILE_SIZE_X * TILE_SIZE_Y * 4];

const vec4 bias = texelFetch(t_bias, ivec2(out_pos[2], 0), 0);

// Initialize the output array with the bias value
for (int i = 0; i < TILE_SIZE_X * TILE_SIZE_Y * 4; i += 4) {
sum[i] = bias.x;
sum[i + 1] = bias.y;
sum[i + 2] = bias.z;
sum[i + 3] = bias.w;
}

int z4 = 0;
// Since the kernel is 1x1, we only have to loop over the depth dimension.
for (int z = 0; z < in_group_size; z += 4, ++z4) {
// During prepacking, the weight tensor has been permuted so that the
// channel (IC) dim is along the x-axis, and the batch (OC) dim is along
// the z-axis.
float kernel_values[4 * 4]; // 4 channels, 4 elements per channel

// Load kernel values from texels to array
[[unroll]] for (int i = 0; i < 4; ++i) {
const vec4 k_tex = texelFetch(t_kernel, ivec2(z + i, out_pos[2]), 0);
kernel_values[i * 4 + 0] = k_tex.x;
kernel_values[i * 4 + 1] = k_tex.y;
kernel_values[i * 4 + 2] = k_tex.z;
kernel_values[i * 4 + 3] = k_tex.w;
}

for (int i = 0; i < TILE_SIZE_X * TILE_SIZE_Y; ++i) {
const vec4 in_tex = texelFetch(t_in, ivec3(pos[i * 2], pos[i * 2 + 1], z4), 0);
// Load the input texel into an array
float tex_values[4];
tex_values[0] = in_tex.x;
tex_values[1] = in_tex.y;
tex_values[2] = in_tex.z;
tex_values[3] = in_tex.w;

// For 2x2 tile size algorithm works as follows.
// To explain the calculations below, the contents of one in_tex and the
// group of 4 texels loaded from t_kernel are shown:
//
// in_tex t_kernel
// -x-> ---x--->
// +---+ +----+----+----+----+
// ^ | w | ^ | D0 | D1 | D2 | D3 |
// | +---+ | +----+----+----+----+
// | | z | | | C0 | C1 | C2 | C3 |
// z +---+ z +----+----+----+----+
// | | y | | | B0 | B2 | B2 | B3 |
// | +---+ | +----+----+----+----+
// | x | | A0 | A1 | A2 | A3 |
// +---+ +----+----+----+----+
//
// In the t_kernel graphic, cells sharing the same letter are from
// the same batch/output channel index, and the number denotes a unique
// channel index. To calculate the output texel, the following
// calculation is performed:
//
// +---+ +----+ +---+ +----+ +---+ +----+ +---+ +----+
// | x | | D0 | | y | | D1 | | z | | D2 | | w | | D3 |
// +---+ +----+ +---+ +----+ +---+ +----+ +---+ +----+
// | x | | C0 | | y | | C1 | | z | | C2 | | w | | C3 |
// +---+X+----+ + +---+X+----+ + +---+X+----+ + +---+X+----+
// | x | | B0 | | y | | B1 | | z | | B2 | | w | | B3 |
// +---+ +----+ +---+ +----+ +---+ +----+ +---+ +----+
// | x | | A0 | | y | | A1 | | z | | A2 | | w | | A3 |
// +---+ +----+ +---+ +----+ +---+ +----+ +---+ +----+
//
// which is what is expressed in the following calculations. This is done
// for each output position.
for (int j = 0; j < 4; ++j) {
sum[i * 4 + j] = tex_values[0] * kernel_values[0 + j] + sum[i * 4 + j];
sum[i * 4 + j] = tex_values[1] * kernel_values[4 + j] + sum[i * 4 + j];
sum[i * 4 + j] = tex_values[2] * kernel_values[8 + j] + sum[i * 4 + j];
sum[i * 4 + j] = tex_values[3] * kernel_values[12 + j] + sum[i * 4 + j];
}
}
}

for (int i = 0; i < TILE_SIZE_X * TILE_SIZE_Y; ++i) {
const ivec3 pos_l = ivec3(pos[i * 2], pos[i * 2 + 1], out_pos[2]);
if (all(lessThan(pos_l, out_limits.xyz))) {
imageStore(t_out, pos_l, op(vec4(sum[i * 4], sum[i * 4 + 1], sum[i * 4 + 2], sum[i * 4 + 3]), out_min, out_max));
}
}
}
21 changes: 21 additions & 0 deletions backends/vulkan/runtime/graph/ops/glsl/conv2d_pw_s1p0.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

conv2d_pw_s1p0:
parameter_names_with_default_values:
OPERATOR: X
NDIM: 3
DTYPE: float
TILE_SIZE_X: 1
TILE_SIZE_Y: 4
generate_variant_forall:
DTYPE:
- VALUE: half
- VALUE: float
shader_variants:
- NAME: conv2d_pw_s1p0
- NAME: conv2d_pw_s1p0_clamp
OPERATOR: clamp(X, A, B)
Loading
Loading