Skip to content

Commit 8aab7d0

Browse files
[ET-VK] Use push constants for image and buffer to nchw prepack nodes. (#11371)
This PR was created by the merge bot to help merge the original PR into the main branch. ghstack PR number: #11305 by @trivedivivek ^ Please use this as the source of truth for the PR details, comments, and reviews ghstack PR base: https://github.com/pytorch/executorch/tree/gh/trivedivivek/109/base ghstack PR head: https://github.com/pytorch/executorch/tree/gh/trivedivivek/109/head Merge bot PR base: https://github.com/pytorch/executorch/tree/main Merge bot PR head: https://github.com/pytorch/executorch/tree/gh/trivedivivek/109/orig @diff-train-skip-merge --------- Co-authored-by: Vivek Trivedi <5340687+trivedivivek@users.noreply.github.com>
1 parent 6415c15 commit 8aab7d0

16 files changed

+83
-42
lines changed

backends/vulkan/runtime/graph/ops/glsl/bitw8_image_to_nchw_nobitw8buffer.glsl

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,15 @@ layout(std430) buffer;
2020

2121
${layout_declare_buffer(B, "w", "nchw_out", "int")}
2222
${layout_declare_tensor(B, "r", "t_in", DTYPE, STORAGE)}
23-
${layout_declare_ubo(B, "ivec4", "tensor_sizes")}
24-
${layout_declare_ubo(B, "int", "out_numel")}
23+
24+
$if USE_PUSH_CONST:
25+
layout(push_constant) uniform restrict Block {
26+
ivec4 tensor_sizes;
27+
int out_numel;
28+
};
29+
$else:
30+
${layout_declare_ubo(B, "ivec4", "tensor_sizes")}
31+
${layout_declare_ubo(B, "int", "out_numel")}
2532

2633
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
2734

backends/vulkan/runtime/graph/ops/glsl/bitw8_image_to_nchw_nobitw8buffer.yaml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ bitw8_image_to_nchw_nobitw8buffer:
88
parameter_names_with_default_values:
99
STORAGE: texture3d
1010
DTYPE: int8
11+
USE_PUSH_CONST: True
1112
generate_variant_forall:
1213
STORAGE:
1314
- VALUE: texture2d
@@ -17,3 +18,5 @@ bitw8_image_to_nchw_nobitw8buffer:
1718
- VALUE: uint8
1819
shader_variants:
1920
- NAME: bitw8_image_to_nchw_nobitw8buffer
21+
- NAME: bitw8_image_to_nchw_nobitw8buffer_no_pc
22+
USE_PUSH_CONST: False

backends/vulkan/runtime/graph/ops/glsl/buffer_to_nchw.glsl

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,17 @@ layout(std430) buffer;
1212

1313
${layout_declare_tensor(0, "w", "nchw_buf", DTYPE, STORAGE)}
1414
${layout_declare_tensor(1, "r", "t_in", DTYPE, STORAGE)}
15-
${layout_declare_ubo(2, "ivec4", "in_sizes")}
16-
${layout_declare_ubo(3, "ivec4", "in_strides")}
17-
${layout_declare_ubo(4, "int", "numel")}
15+
16+
$if USE_PUSH_CONST:
17+
layout(push_constant) uniform restrict Block {
18+
ivec4 in_sizes;
19+
ivec4 in_strides;
20+
int numel;
21+
};
22+
$else:
23+
${layout_declare_ubo(2, "ivec4", "in_sizes")}
24+
${layout_declare_ubo(3, "ivec4", "in_strides")}
25+
${layout_declare_ubo(4, "int", "numel")}
1826

1927
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
2028

backends/vulkan/runtime/graph/ops/glsl/buffer_to_nchw.yaml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ buffer_to_nchw:
88
parameter_names_with_default_values:
99
DTYPE: float
1010
STORAGE: buffer
11+
USE_PUSH_CONST: True
1112
generate_variant_forall:
1213
DTYPE:
1314
- VALUE: half
@@ -17,3 +18,5 @@ buffer_to_nchw:
1718
- VALUE: uint8
1819
shader_variants:
1920
- NAME: buffer_to_nchw
21+
- NAME: buffer_to_nchw_no_pc
22+
USE_PUSH_CONST: False

backends/vulkan/runtime/graph/ops/glsl/conv2d_dw_prepack_weights.glsl

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,11 +26,8 @@ layout(set = 0, binding = 1) buffer PRECISION restrict readonly Buffer {
2626
BUF_T buffer_in[];
2727
};
2828

29-
layout(set = 0, binding = 2) uniform PRECISION restrict Sizes {
29+
layout(push_constant) uniform PRECISION restrict Block {
3030
ivec4 sizes;
31-
};
32-
33-
layout(set = 0, binding = 3) uniform PRECISION restrict OriginalSizes {
3431
ivec4 original_sizes;
3532
};
3633

backends/vulkan/runtime/graph/ops/glsl/conv2d_prepack_weights.glsl

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,11 +26,8 @@ layout(set = 0, binding = 1) buffer PRECISION restrict readonly Buffer {
2626
BUF_T buffer_in[];
2727
};
2828

29-
layout(set = 0, binding = 2) uniform PRECISION restrict Sizes {
29+
layout(push_constant) uniform PRECISION restrict Block {
3030
ivec4 sizes;
31-
};
32-
33-
layout(set = 0, binding = 3) uniform PRECISION restrict OriginalSizes {
3431
ivec4 original_sizes;
3532
};
3633

backends/vulkan/runtime/graph/ops/glsl/conv_transpose2d_prepack_weights.glsl

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,11 +26,8 @@ layout(set = 0, binding = 1) buffer PRECISION restrict readonly Buffer {
2626
BUF_T buffer_in[];
2727
};
2828

29-
layout(set = 0, binding = 2) uniform PRECISION restrict Sizes {
29+
layout(push_constant) uniform PRECISION restrict Block {
3030
ivec4 sizes;
31-
};
32-
33-
layout(set = 0, binding = 3) uniform PRECISION restrict OriginalSizes {
3431
ivec4 original_sizes;
3532
};
3633

backends/vulkan/runtime/graph/ops/glsl/image_to_nchw.glsl

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,17 @@ layout(std430) buffer;
2121

2222
${layout_declare_buffer(B, "w", "buf_out", DTYPE)}
2323
${layout_declare_tensor(B, "r", "t_in", DTYPE, STORAGE)}
24-
${layout_declare_ubo(B, "ivec4", "sizes")}
25-
$if not TO_STAGING:
26-
${layout_declare_ubo(B, "ivec4", "buf_strides")}
24+
25+
$if USE_PUSH_CONST:
26+
layout(push_constant) uniform restrict Block {
27+
ivec4 sizes;
28+
$if not TO_STAGING:
29+
ivec4 buf_strides;
30+
};
31+
$else:
32+
${layout_declare_ubo(B, "ivec4", "sizes")}
33+
$if not TO_STAGING:
34+
${layout_declare_ubo(B, "ivec4", "buf_strides")}
2735

2836
#include "indexing_utils.h"
2937

backends/vulkan/runtime/graph/ops/glsl/image_to_nchw.yaml

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ image_to_nchw:
99
DTYPE: float
1010
STORAGE: texture3d
1111
TO_STAGING: True
12+
USE_PUSH_CONST: True
1213
generate_variant_forall:
1314
DTYPE:
1415
- VALUE: half
@@ -22,3 +23,11 @@ image_to_nchw:
2223
STORAGE: texture2d
2324
- NAME: clone_image_to_buffer
2425
TO_STAGING: False
26+
- NAME: image_to_nchw_no_pc_texture3d
27+
USE_PUSH_CONST: False
28+
- NAME: image_to_nchw_no_pc_texture2d
29+
STORAGE: texture2d
30+
USE_PUSH_CONST: False
31+
- NAME: clone_image_to_buffer_no_pc
32+
TO_STAGING: False
33+
USE_PUSH_CONST: False

backends/vulkan/runtime/graph/ops/impl/Clone.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -88,9 +88,9 @@ void add_image_to_buffer_node(
8888
// Input and Outputs
8989
{{buffer, vkapi::kWrite}, {image, vkapi::kRead}},
9090
// Parameter Buffers
91-
{graph.sizes_ubo(image), graph.strides_ubo(buffer)},
92-
// Push Constants
9391
{},
92+
// Push Constants
93+
{graph.sizes_pc_of(image), graph.strides_pc_of(buffer)},
9494
// Specialization Constants
9595
{graph.hashed_layout_of(image)},
9696
// Resize Args

backends/vulkan/runtime/graph/ops/impl/Convolution.cpp

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -211,18 +211,20 @@ ValueRef prepack_weights(
211211
vkapi::ShaderInfo shader =
212212
get_conv2d_shader(graph, *t, /*prepack_weights = */ true, method, vref);
213213

214+
const auto original_sizes_pc =
215+
utils::make_ivec4(original_sizes, /*reverse = */ true);
214216
graph.prepack_nodes().emplace_back(new PrepackNode(
215217
graph,
216218
shader,
217219
graph.create_global_wg_size(v),
218220
graph.create_local_wg_size(v),
219221
vref,
220222
v,
221-
{t->sizes_ubo(),
222-
graph.create_params_buffer(
223-
utils::make_ivec4(original_sizes, /*reverse = */ true))},
223+
{},
224224
// Specialization constants
225-
{SV(t->packed_dim())}));
225+
{SV(t->packed_dim())},
226+
{graph.sizes_pc_of(v),
227+
PushConstantDataInfo(&original_sizes_pc, sizeof(original_sizes_pc))}));
226228

227229
return v;
228230
}

backends/vulkan/runtime/graph/ops/impl/Staging.cpp

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -113,18 +113,18 @@ void add_tensor_to_staging_node(
113113
vkapi::ShaderInfo shader = get_tensor_to_nchw_shader(
114114
*graph.get_tensor(in_tensor), graph.int8_buffers_enabled());
115115

116-
vkapi::ParamsBindList ubos;
116+
std::vector<PushConstantDataInfo> pcs;
117117
if (graph.is_buffer_storage(in_tensor)) {
118-
ubos.append(
119-
{graph.sizes_ubo(in_tensor),
120-
graph.strides_ubo(in_tensor),
121-
graph.numel_ubo(in_tensor)});
118+
pcs = {
119+
graph.sizes_pc_of(in_tensor),
120+
graph.strides_pc_of(in_tensor),
121+
graph.numel_pc_of(in_tensor)};
122122
} else {
123-
ubos.append({graph.sizes_ubo(in_tensor)});
123+
pcs = {graph.sizes_pc_of(in_tensor)};
124124
}
125125

126126
if (is_bitw8_shader(shader)) {
127-
ubos.append({graph.numel_ubo(in_tensor)});
127+
pcs.push_back(graph.numel_pc_of(in_tensor));
128128
}
129129

130130
graph.execute_nodes().emplace_back(new DynamicDispatchNode(
@@ -135,9 +135,9 @@ void add_tensor_to_staging_node(
135135
// Input and Outputs
136136
{{out_staging, vkapi::kWrite}, {in_tensor, vkapi::kRead}},
137137
// Parameter Buffers
138-
ubos,
139-
// Push Constants
140138
{},
139+
// Push Constants
140+
pcs,
141141
// Specialization Constants
142142
{graph.hashed_layout_of(in_tensor)},
143143
// Resize Args

backends/vulkan/runtime/graph/ops/utils/StagingUtils.cpp

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,25 +59,35 @@ vkapi::ShaderInfo get_nchw_to_tensor_shader(
5959

6060
vkapi::ShaderInfo get_tensor_to_nchw_shader(
6161
const api::vTensor& v_src,
62-
bool int8_buffer_enabled) {
62+
bool int8_buffer_enabled,
63+
bool push_constant_variant) {
6364
std::string kernel_name;
6465
kernel_name.reserve(kShaderNameReserve);
6566

6667
if (is_bitw8(v_src.dtype()) && v_src.storage_type() != utils::kBuffer &&
6768
!int8_buffer_enabled) {
6869
kernel_name = "bitw8_image_to_nchw_nobitw8buffer";
70+
if (!push_constant_variant) {
71+
kernel_name += "_no_pc";
72+
}
6973
add_storage_type_suffix(kernel_name, v_src);
7074
add_dtype_suffix(kernel_name, v_src);
7175
return VK_KERNEL_FROM_STR(kernel_name);
7276
}
7377

7478
if (v_src.storage_type() == utils::kBuffer) {
7579
kernel_name = "buffer_to_nchw";
80+
if (!push_constant_variant) {
81+
kernel_name += "_no_pc";
82+
}
7683
add_dtype_suffix(kernel_name, v_src);
7784
return VK_KERNEL_FROM_STR(kernel_name);
7885
}
7986

8087
kernel_name = "image_to_nchw";
88+
if (!push_constant_variant) {
89+
kernel_name += "_no_pc";
90+
}
8191
add_storage_type_suffix(kernel_name, v_src);
8292
add_dtype_suffix(kernel_name, v_src);
8393

backends/vulkan/runtime/graph/ops/utils/StagingUtils.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ vkapi::ShaderInfo get_nchw_to_tensor_shader(
1818
bool push_constant_variant = true);
1919
vkapi::ShaderInfo get_tensor_to_nchw_shader(
2020
const api::vTensor& v_src,
21-
bool int8_buffer_enabled = true);
21+
bool int8_buffer_enabled = true,
22+
bool push_constant_variant = true);
2223

2324
} // namespace vkcompute

backends/vulkan/test/utils/test_utils.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ void record_buffer_to_nchw_op(
5151
vkapi::VulkanBuffer& dst_buffer) {
5252
vkapi::PipelineBarrier pipeline_barrier{};
5353
context->submit_compute_job(
54-
get_tensor_to_nchw_shader(v_src),
54+
get_tensor_to_nchw_shader(v_src, true, false),
5555
pipeline_barrier,
5656
{uint32_t(v_src.numel()), 1, 1},
5757
{64, 1, 1},
@@ -99,7 +99,7 @@ void record_image_to_nchw_op(
9999
vkapi::SpecVarList specialization_constants = {v_src.hashed_layout()};
100100

101101
context->submit_compute_job(
102-
get_tensor_to_nchw_shader(v_src),
102+
get_tensor_to_nchw_shader(v_src, true, false),
103103
pipeline_barrier,
104104
v_src.logical_limits(),
105105
adaptive_work_group_size(v_src.logical_limits()),
@@ -119,7 +119,7 @@ void record_bitw8_image_to_nchw_nobitw8buffer_op(
119119
uint32_t buffer_len = utils::safe_downcast<uint32_t>(dst_buffer.numel() / 4);
120120
utils::uvec3 global_wg_size = {buffer_len, 1, 1};
121121

122-
std::string kernel_name = "bitw8_image_to_nchw_nobitw8buffer";
122+
std::string kernel_name = "bitw8_image_to_nchw_nobitw8buffer_no_pc";
123123
add_storage_type_suffix(kernel_name, v_src);
124124
add_dtype_suffix(kernel_name, v_src);
125125

backends/vulkan/test/vulkan_compute_api_test.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1640,8 +1640,7 @@ TEST(VulkanComputeGraphTest, test_simple_shared_objects_with_resize) {
16401640
out.staging = graph.set_output_tensor(out.value);
16411641

16421642
// +1: staging buffer input tensor
1643-
// +1: staging buffer for the output tensor
1644-
expected_vma_allocation_count += 2;
1643+
expected_vma_allocation_count += 1;
16451644
EXPECT_EQ(get_vma_allocation_count(), expected_vma_allocation_count);
16461645

16471646
graph.prepare();

0 commit comments

Comments
 (0)