Skip to content

Commit 03b1cc8

Browse files
authored
[ET-VK] Using push constants for buffer to image prepack nodes.
Differential Revision: D70102398
Pull Request resolved: #11252
1 parent aa3e2b1 commit 03b1cc8

13 files changed

+85
-35
lines changed

backends/vulkan/runtime/graph/ops/glsl/nchw_to_bitw8_image_nobitw8buffer.glsl

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,13 @@ layout(std430) buffer;
2222

2323
${layout_declare_tensor(B, "w", "t_out", DTYPE, STORAGE)}
2424
${layout_declare_buffer(B, "r", "nchw_in", "int")}
25-
${layout_declare_ubo(B, "ivec4", "sizes")}
25+
26+
$if USE_PUSH_CONST:
27+
layout(push_constant) uniform restrict Block {
28+
ivec4 sizes;
29+
};
30+
$else:
31+
${layout_declare_ubo(B, "ivec4", "sizes")}
2632

2733
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
2834

backends/vulkan/runtime/graph/ops/glsl/nchw_to_bitw8_image_nobitw8buffer.yaml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ nchw_to_bitw8_image_nobitw8buffer:
88
parameter_names_with_default_values:
99
STORAGE: texture3d
1010
DTYPE: int8
11+
USE_PUSH_CONST: True
1112
generate_variant_forall:
1213
STORAGE:
1314
- VALUE: texture2d
@@ -17,3 +18,5 @@ nchw_to_bitw8_image_nobitw8buffer:
1718
- VALUE: uint8
1819
shader_variants:
1920
- NAME: nchw_to_bitw8_image_nobitw8buffer
21+
- NAME: nchw_to_bitw8_image_nobitw8buffer_no_pc
22+
USE_PUSH_CONST: False

backends/vulkan/runtime/graph/ops/glsl/nchw_to_buffer.glsl

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,17 @@ layout(std430) buffer;
1212

1313
${layout_declare_tensor(0, "w", "t_out", DTYPE, STORAGE)}
1414
${layout_declare_tensor(1, "r", "nchw_in", DTYPE, STORAGE)}
15-
${layout_declare_ubo(2, "ivec4", "out_sizes")}
16-
${layout_declare_ubo(3, "ivec4", "out_strides")}
17-
${layout_declare_ubo(4, "int", "numel")}
15+
16+
$if USE_PUSH_CONST:
17+
layout(push_constant) uniform restrict Block {
18+
ivec4 out_sizes;
19+
ivec4 out_strides;
20+
int numel;
21+
};
22+
$else:
23+
${layout_declare_ubo(2, "ivec4", "out_sizes")}
24+
${layout_declare_ubo(3, "ivec4", "out_strides")}
25+
${layout_declare_ubo(4, "int", "numel")}
1826

1927
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
2028

backends/vulkan/runtime/graph/ops/glsl/nchw_to_buffer.yaml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ nchw_to_buffer:
88
parameter_names_with_default_values:
99
DTYPE: float
1010
STORAGE: buffer
11+
USE_PUSH_CONST: True
1112
generate_variant_forall:
1213
DTYPE:
1314
- VALUE: half
@@ -17,3 +18,5 @@ nchw_to_buffer:
1718
- VALUE: uint8
1819
shader_variants:
1920
- NAME: nchw_to_buffer
21+
- NAME: nchw_to_buffer_no_pc
22+
USE_PUSH_CONST: False

backends/vulkan/runtime/graph/ops/glsl/nchw_to_image.glsl

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,17 @@ layout(std430) buffer;
2121

2222
${layout_declare_tensor(B, "w", "t_out", DTYPE, STORAGE)}
2323
${layout_declare_buffer(B, "r", "buf_in", DTYPE)}
24-
${layout_declare_ubo(B, "ivec4", "sizes")}
25-
$if not FROM_STAGING:
26-
${layout_declare_ubo(B, "ivec4", "buf_strides")}
24+
25+
$if USE_PUSH_CONST:
26+
layout(push_constant) uniform restrict Block {
27+
ivec4 sizes;
28+
$if not FROM_STAGING:
29+
ivec4 buf_strides;
30+
};
31+
$else:
32+
${layout_declare_ubo(B, "ivec4", "sizes")}
33+
$if not FROM_STAGING:
34+
${layout_declare_ubo(B, "ivec4", "buf_strides")}
2735

2836
#include "indexing_utils.h"
2937

backends/vulkan/runtime/graph/ops/glsl/nchw_to_image.yaml

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ nchw_to_image:
99
STORAGE: texture3d
1010
DTYPE: float
1111
FROM_STAGING: True
12+
USE_PUSH_CONST: True
1213
generate_variant_forall:
1314
DTYPE:
1415
- VALUE: half
@@ -22,3 +23,11 @@ nchw_to_image:
2223
STORAGE: texture2d
2324
- NAME: clone_buffer_to_image
2425
FROM_STAGING: False
26+
- NAME: nchw_to_image_no_pc_texture3d
27+
USE_PUSH_CONST: False
28+
- NAME: nchw_to_image_no_pc_texture2d
29+
STORAGE: texture2d
30+
USE_PUSH_CONST: False
31+
- NAME: clone_buffer_to_image_no_pc
32+
FROM_STAGING: False
33+
USE_PUSH_CONST: False

backends/vulkan/runtime/graph/ops/impl/Clone.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -105,9 +105,9 @@ void add_buffer_to_image_node(
105105
// Input and Outputs
106106
{{image, vkapi::kWrite}, {buffer, vkapi::kRead}},
107107
// Parameter Buffers
108-
{graph.sizes_ubo(image), graph.strides_ubo(buffer)},
109-
// Push Constants
110108
{},
109+
// Push Constants
110+
{graph.sizes_pc_of(image), graph.strides_pc_of(buffer)},
111111
// Specialization Constants
112112
{graph.hashed_layout_of(image)},
113113
// Resize Args

backends/vulkan/runtime/graph/ops/impl/Convolution.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -106,9 +106,10 @@ ValueRef prepack_biases(
106106
graph.create_local_wg_size(v),
107107
vref,
108108
v,
109-
{t->sizes_ubo()},
109+
{},
110110
// Specialization constants
111-
{t->hashed_layout()}));
111+
{t->hashed_layout()},
112+
{graph.sizes_pc_of(v)}));
112113

113114
return v;
114115
}

backends/vulkan/runtime/graph/ops/impl/Staging.cpp

Lines changed: 17 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -28,14 +28,14 @@ void add_staging_to_tensor_node(
2828
vkapi::ShaderInfo shader = get_nchw_to_tensor_shader(
2929
*graph.get_tensor(out_tensor), graph.int8_buffers_enabled());
3030

31-
vkapi::ParamsBindList ubos;
31+
std::vector<PushConstantDataInfo> pcs;
3232
if (graph.is_buffer_storage(out_tensor)) {
33-
ubos.append(
34-
{graph.sizes_ubo(out_tensor),
35-
graph.strides_ubo(out_tensor),
36-
graph.numel_ubo(out_tensor)});
33+
pcs = {
34+
graph.sizes_pc_of(out_tensor),
35+
graph.strides_pc_of(out_tensor),
36+
graph.numel_pc_of(out_tensor)};
3737
} else {
38-
ubos.append({graph.sizes_ubo(out_tensor)});
38+
pcs = {graph.sizes_pc_of(out_tensor)};
3939
}
4040

4141
graph.execute_nodes().emplace_back(new DispatchNode(
@@ -46,9 +46,9 @@ void add_staging_to_tensor_node(
4646
// Input and Outputs
4747
{{out_tensor, vkapi::kWrite}, {in_staging, vkapi::kRead}},
4848
// Parameter Buffers
49-
ubos,
50-
// Push Constants
5149
{},
50+
// Push Constants
51+
pcs,
5252
// Specialization Constants
5353
{graph.hashed_layout_of(out_tensor)},
5454
// Resize Args
@@ -127,14 +127,14 @@ void add_prepack_standard_node(
127127
vkapi::ShaderInfo shader = get_nchw_to_tensor_shader(
128128
*graph.get_tensor(tensor), graph.int8_buffers_enabled());
129129

130-
vkapi::ParamsBindList ubos;
130+
std::vector<PushConstantDataInfo> pcs;
131131
if (graph.is_buffer_storage(tensor)) {
132-
ubos.append(
133-
{graph.sizes_ubo(tensor),
134-
graph.strides_ubo(tensor),
135-
graph.numel_ubo(tensor)});
132+
pcs = {
133+
graph.sizes_pc_of(tensor),
134+
graph.strides_pc_of(tensor),
135+
graph.numel_pc_of(tensor)};
136136
} else {
137-
ubos.append({graph.sizes_ubo(tensor)});
137+
pcs = {graph.sizes_pc_of(tensor)};
138138
}
139139

140140
int transpose_hw_spec = transpose_hw ? 1 : 0;
@@ -148,9 +148,10 @@ void add_prepack_standard_node(
148148
tensor_data,
149149
tensor,
150150
// Parameter Buffers
151-
ubos,
151+
{},
152152
// Specialization Constants
153-
{graph.hashed_layout_of(tensor), transpose_hw_spec}));
153+
{graph.hashed_layout_of(tensor), transpose_hw_spec},
154+
pcs));
154155
}
155156

156157
ValueRef prepack_standard(

backends/vulkan/runtime/graph/ops/utils/StagingUtils.cpp

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,25 +22,35 @@ bool is_bitw8(vkapi::ScalarType dtype) {
2222

2323
vkapi::ShaderInfo get_nchw_to_tensor_shader(
2424
const api::vTensor& v_dst,
25-
const bool int8_buffer_enabled) {
25+
bool int8_buffer_enabled,
26+
bool push_constant_variant) {
2627
std::string kernel_name;
2728
kernel_name.reserve(kShaderNameReserve);
2829

2930
if (is_bitw8(v_dst.dtype()) && v_dst.storage_type() != utils::kBuffer &&
3031
!int8_buffer_enabled) {
3132
kernel_name = "nchw_to_bitw8_image_nobitw8buffer";
33+
if (!push_constant_variant) {
34+
kernel_name += "_no_pc";
35+
}
3236
add_storage_type_suffix(kernel_name, v_dst);
3337
add_dtype_suffix(kernel_name, v_dst);
3438
return VK_KERNEL_FROM_STR(kernel_name);
3539
}
3640

3741
if (v_dst.storage_type() == utils::kBuffer) {
3842
kernel_name = "nchw_to_buffer";
43+
if (!push_constant_variant) {
44+
kernel_name += "_no_pc";
45+
}
3946
add_dtype_suffix(kernel_name, v_dst);
4047
return VK_KERNEL_FROM_STR(kernel_name);
4148
}
4249

4350
kernel_name = "nchw_to_image";
51+
if (!push_constant_variant) {
52+
kernel_name += "_no_pc";
53+
}
4454
add_storage_type_suffix(kernel_name, v_dst);
4555
add_dtype_suffix(kernel_name, v_dst);
4656

backends/vulkan/runtime/graph/ops/utils/StagingUtils.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,8 @@ namespace vkcompute {
1414

1515
vkapi::ShaderInfo get_nchw_to_tensor_shader(
1616
const api::vTensor& v_dst,
17-
bool int8_buffer_enabled = true);
17+
bool int8_buffer_enabled = true,
18+
bool push_constant_variant = true);
1819
vkapi::ShaderInfo get_tensor_to_nchw_shader(
1920
const api::vTensor& v_src,
2021
bool int8_buffer_enabled = true);

backends/vulkan/test/utils/test_utils.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ void record_nchw_to_buffer_op(
2828
vkapi::PipelineBarrier pipeline_barrier{};
2929

3030
context->submit_compute_job(
31-
get_nchw_to_tensor_shader(v_dst),
31+
get_nchw_to_tensor_shader(v_dst, true, false),
3232
pipeline_barrier,
3333
{uint32_t(v_dst.numel()), 1, 1},
3434
{64, 1, 1},
@@ -74,7 +74,9 @@ void record_nchw_to_image_op(
7474

7575
context->submit_compute_job(
7676
get_nchw_to_tensor_shader(
77-
v_dst, context->adapter_ptr()->has_full_int8_buffers_support()),
77+
v_dst,
78+
context->adapter_ptr()->has_full_int8_buffers_support(),
79+
false),
7880
pipeline_barrier,
7981
v_dst.logical_limits(),
8082
adaptive_work_group_size(v_dst.logical_limits()),

backends/vulkan/test/vulkan_compute_api_test.cpp

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1601,8 +1601,7 @@ TEST(VulkanComputeGraphTest, test_simple_shared_objects_with_resize) {
16011601
/*shared_object_idx = */ 4);
16021602

16031603
// +2: staging buffer for each input tensor (sizes are passed as push constants, so no sizes_ubo() is allocated)
1604-
// +2: staging buffer for each input tensor
1605-
expected_vma_allocation_count += 4;
1604+
expected_vma_allocation_count += 2;
16061605
EXPECT_EQ(get_vma_allocation_count(), expected_vma_allocation_count);
16071606

16081607
ValueRef c = graph.add_tensor(
@@ -1622,8 +1621,7 @@ TEST(VulkanComputeGraphTest, test_simple_shared_objects_with_resize) {
16221621
/*shared_object_idx = */ 2);
16231622

16241623
// +1: staging buffer for the input tensor (sizes are passed as push constants, so no sizes_ubo() is allocated)
1625-
// +1: staging buffer for the input tensor
1626-
expected_vma_allocation_count += 2;
1624+
expected_vma_allocation_count += 1;
16271625
EXPECT_EQ(get_vma_allocation_count(), expected_vma_allocation_count);
16281626

16291627
ValueRef e = graph.add_tensor(

0 commit comments

Comments
 (0)