Skip to content

Commit 1a11267

Browse files
authored
[ET-VK] Introduce generalized shaders for transfer ops and use it for select and slice (#11304)
## Changes * Introduce `transfer_buffer.glsl` and `transfer_texture.glsl`, and `Transfer.cpp` which generalizes shaders where each element of the output is copied from a unique element of the input. * Update `Slice.cpp` and `Select.cpp` to use `Transfer.cpp` * Remove old implementations of slice and select ## Motivation With this new implementation, the op can now support both buffers and textures of any packing. There are also benefits of code consolidation. Differential Revision: [D75686050](https://our.internmc.facebook.com/intern/diff/D75686050/)
1 parent e088221 commit 1a11267

32 files changed

+792
-873
lines changed

backends/vulkan/runtime/graph/ComputeGraph.cpp

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -492,14 +492,24 @@ vkapi::BufferBindInfo ComputeGraph::get_or_create_int_param_buffer(
492492
const ValueRef idx) {
493493
if (values_.at(idx).isInt()) {
494494
const int32_t val = extract_scalar<int32_t>(idx);
495-
create_params_buffer(val);
495+
return create_params_buffer(val);
496496
} else if (values_.at(idx).isSymInt()) {
497497
SymIntPtr symint = get_symint(idx);
498498
return vkapi::BufferBindInfo(symint->gpu_buffer.buffer());
499499
}
500500
VK_THROW("Cannot create a int param buffer for the given value");
501501
}
502502

503+
vkapi::BufferBindInfo ComputeGraph::get_or_create_int_param_buffer(
504+
const ValueRef idx,
505+
const int32_t default_val) {
506+
if (values_.at(idx).isNone()) {
507+
return create_params_buffer(default_val);
508+
} else {
509+
return get_or_create_int_param_buffer(idx);
510+
}
511+
}
512+
503513
void ComputeGraph::set_symint(const ValueRef idx, const int32_t val) {
504514
get_symint(idx)->set(val);
505515
}
@@ -693,6 +703,12 @@ void ComputeGraph::resize_input(
693703
get_tensor(io_val.value)->virtual_resize(new_sizes);
694704
}
695705

706+
void ComputeGraph::virtual_resize(
707+
const ValueRef idx,
708+
const std::vector<int64_t>& new_sizes) {
709+
get_tensor(idx)->virtual_resize(new_sizes);
710+
}
711+
696712
void ComputeGraph::propagate_resize() {
697713
for (std::unique_ptr<ExecuteNode>& node : execute_nodes_) {
698714
node->trigger_resize(this);

backends/vulkan/runtime/graph/ComputeGraph.h

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -398,6 +398,19 @@ class ComputeGraph final {
398398
std::optional<T> extract_optional_scalar(const ValueRef idx) {
399399
if (val_is_none(idx)) {
400400
return ::std::nullopt;
401+
} else if (val_is_symint(idx)) {
402+
return utils::safe_downcast<T>(read_symint(idx));
403+
} else {
404+
return extract_scalar<T>(idx);
405+
}
406+
}
407+
408+
template <typename T>
409+
T extract_optional_scalar(const ValueRef idx, const T default_val) {
410+
if (val_is_none(idx)) {
411+
return default_val;
412+
} else if (val_is_symint(idx)) {
413+
return utils::safe_downcast<T>(read_symint(idx));
401414
} else {
402415
return extract_scalar<T>(idx);
403416
}
@@ -609,6 +622,10 @@ class ComputeGraph final {
609622
*/
610623
vkapi::BufferBindInfo get_or_create_int_param_buffer(const ValueRef idx);
611624

625+
vkapi::BufferBindInfo get_or_create_int_param_buffer(
626+
const ValueRef idx,
627+
const int32_t default_value);
628+
612629
void set_symint(const ValueRef idx, const int32_t val);
613630

614631
int32_t read_symint(const ValueRef idx);
@@ -753,6 +770,9 @@ class ComputeGraph final {
753770
//
754771

755772
void resize_input(const int64_t idx, const std::vector<int64_t>& new_sizes);
773+
void virtual_resize(
774+
const ValueRef idx,
775+
const std::vector<int64_t>& new_sizes);
756776
void propagate_resize();
757777

758778
//
Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#ifndef SELECT_GLSLH
10+
#define SELECT_GLSLH
11+
12+
/*
13+
* Enable the fast path if a texel loaded from the input texture can be used as
14+
* is to store to the output texture. The following conditions must be met:
15+
*
16+
* 1. The input and output textures have the same packed dimension.
17+
* 2. The selected_dim must not be the packed dimension of the input.
18+
* 3. The packed dimension of the input must "map" to the packed dimension of
19+
* the output. This occurs if selected_dim is greater than the packed dimension
20+
* of the input.
21+
*/
22+
bool can_use_fast_path() {
23+
if (out_packed_dim != in_packed_dim) {
24+
return false;
25+
}
26+
if (selected_dim <= in_packed_dim) {
27+
return false;
28+
}
29+
return true;
30+
}
31+
32+
/*
33+
* Given an output tensor index, return the corresponding input tensor index for
34+
* the select operator. This is done by "inserting" the select index at the
35+
* selected_dim in the input tensor index.
36+
*
37+
* A simple example is (note all tensor index are in WHCN order):
38+
* out_tidx = [7, 5, 9]
39+
* selected_dim = 2
40+
* index = 3
41+
* in_tidx = [7, 3, 5, 9]
42+
*
43+
* This function assumes that the following variables are defined in the layout:
44+
* - in_sizes
45+
* - selected_dim
46+
* - index
47+
*/
48+
ivec4 out_tidx_to_in_tidx(const ivec4 out_tidx) {
49+
ivec4 in_tidx = ivec4(0);
50+
51+
int adjusted_index = index;
52+
if (index < 0) {
53+
adjusted_index = index + in_sizes[selected_dim];
54+
}
55+
56+
// Handle different dimensions for selection
57+
if (selected_dim == 0) {
58+
// Select from width dimension
59+
in_tidx = ivec4(adjusted_index, out_tidx.x, out_tidx.y, out_tidx.z);
60+
} else if (selected_dim == 1) {
61+
// Select from height dimension
62+
in_tidx = ivec4(out_tidx.x, adjusted_index, out_tidx.y, out_tidx.z);
63+
} else if (selected_dim == 2) {
64+
// Select from channel dimension
65+
in_tidx = ivec4(out_tidx.x, out_tidx.y, adjusted_index, out_tidx.z);
66+
} else if (selected_dim == 3) {
67+
// Select from batch dimension
68+
in_tidx = ivec4(out_tidx.x, out_tidx.y, out_tidx.z, adjusted_index);
69+
}
70+
71+
return in_tidx;
72+
}
73+
74+
#endif // SELECT_GLSLH

backends/vulkan/runtime/graph/ops/glsl/select_batch_4d.glsl

Lines changed: 0 additions & 52 deletions
This file was deleted.

backends/vulkan/runtime/graph/ops/glsl/select_channel_3d.glsl

Lines changed: 0 additions & 50 deletions
This file was deleted.

backends/vulkan/runtime/graph/ops/glsl/select_channel_3d.yaml

Lines changed: 0 additions & 10 deletions
This file was deleted.

backends/vulkan/runtime/graph/ops/glsl/select_channel_4d.glsl

Lines changed: 0 additions & 65 deletions
This file was deleted.

backends/vulkan/runtime/graph/ops/glsl/select_height_3d.glsl

Lines changed: 0 additions & 62 deletions
This file was deleted.

0 commit comments

Comments
 (0)