
Commit 401016d

[ET-VK][ez] Fix handling of assert ops (#11349)
## Changes

* Apply `RemoveAssertsTransform` as part of `vulkan_preprocess`
* Do not call `RemoveAssertsTransform` before lowering the graph
* Register ops related to asserts in the operator registry as ephemeral ops

## Motivation

Assert ops are not implemented in Vulkan, so previously `RemoveAssertsTransform()` was called on the graph before the lowering process. However, it turns out that the assert ops are required to handle dynamic shapes correctly, because they place constraints on the possible range of symbolic integers. If they are not present, then re-tracing the graph during a recompile (which may occur during a graph transform pass) may fail. Therefore, instead of calling the transform before lowering, call it inside `vulkan_preprocess`, at a point after which no subsequent pass will attempt to trace the graph.

Differential Revision: [D75686048](https://our.internmc.facebook.com/intern/diff/D75686048/)
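For context: both guard/assert ops involved return `None` and have no users in the FX graph, so stripping them (once no subsequent pass will re-trace the graph) amounts to erasing the matching `call_function` nodes. Below is a minimal sketch of such a pass, assuming a plain `torch.fx.GraphModule`; the actual `RemoveAssertsTransform` lives in `executorch.backends.vulkan._passes.remove_asserts`, and the helper name here is hypothetical.

```python
# Minimal sketch, not the actual RemoveAssertsTransform implementation.
import torch
from torch.fx import GraphModule

# The guard/assert ops registered as ephemeral in this commit.
_ASSERT_OPS = {
    torch.ops.aten._assert_scalar.default,
    torch.ops.aten.sym_constrain_range_for_size.default,
}


def remove_assert_ops(graph_module: GraphModule) -> GraphModule:
    # Assert ops return None and have no users, so each matching node can
    # be erased directly without rewiring any consumers.
    for node in list(graph_module.graph.nodes):
        if node.op == "call_function" and node.target in _ASSERT_OPS:
            graph_module.graph.erase_node(node)
    graph_module.graph.lint()
    graph_module.recompile()
    return graph_module
```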

File tree

6 files changed: +17 −9 lines

* backends/vulkan/_passes/fuse_quantized_ops.py
* backends/vulkan/op_registry.py
* backends/vulkan/partitioner/vulkan_partitioner.py
* backends/vulkan/vulkan_preprocess.py
* examples/models/llama/TARGETS
* examples/models/llama/export_llama_lib.py

backends/vulkan/_passes/fuse_quantized_ops.py

Lines changed: 4 additions & 1 deletion

```diff
@@ -17,6 +17,7 @@
 from executorch.exir import ExportedProgram
 from executorch.exir.dialects._ops import ops as exir_ops
 from executorch.exir.pass_base import ExportPass, PassResult
+from executorch.exir.passes import dead_code_elimination_pass

 #################
 ## linear_qcnw ##
@@ -224,6 +225,8 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
         )

         graph_module.recompile()
-        graph_module = super().call(graph_module).graph_module
+        dead_code_elimination_pass(graph_module)

+        # Re-trace the graph since new nodes were (potentially) inserted
+        graph_module = super().call(graph_module).graph_module
         return PassResult(graph_module, True)
```

backends/vulkan/op_registry.py

Lines changed: 7 additions & 0 deletions

```diff
@@ -231,6 +231,13 @@ def update_features_impl(op: OpKey):
         # Symbolic integer ops
         torch.ops.aten.sym_size.int,
         operator.add,
+        operator.lt,
+        operator.gt,
+        operator.ge,
+        operator.le,
+        # Guard and assert ops
+        torch.ops.aten._assert_scalar.default,
+        torch.ops.aten.sym_constrain_range_for_size.default,
     ]
 )
 def register_ephemeral_op(features: OpFeatures):
```

backends/vulkan/partitioner/vulkan_partitioner.py

Lines changed: 4 additions & 3 deletions

```diff
@@ -146,10 +146,11 @@ def op_node_is_compatible( # noqa: C901: Function is too complex
     def node_is_compatible(
         self, node: torch.fx.Node, features: Optional[OpFeatures] = None
     ) -> Tuple[bool, str]:
-        if utils.is_symint_node(node):
-            return node.target in vulkan_supported_ops, "Op is compatible"
-        elif utils.is_tensor_node(node):
+        if utils.is_tensor_node(node):
             return self.op_node_is_compatible(node, features=features)
+        # For non-tensor nodes, just check if the op is registered
+        elif hasattr(node, "target"):
+            return node.target in vulkan_supported_ops, "Op is compatible"

         return False, f"Unsupported node type: {node.format_node()}"
```

backends/vulkan/vulkan_preprocess.py

Lines changed: 2 additions & 0 deletions

```diff
@@ -29,6 +29,7 @@
     SqueezeUnsqueezeInputs,
     TagMemoryMetaPass,
 )
+from executorch.backends.vulkan._passes.remove_asserts import RemoveAssertsTransform

 from executorch.backends.vulkan.serialization.vulkan_graph_builder import VkGraphBuilder
 from executorch.backends.vulkan.serialization.vulkan_graph_schema import (
@@ -172,6 +173,7 @@ def preprocess( # noqa: C901
         program = apply_passes(
             program,
             [
+                RemoveAssertsTransform(),
                 # Since this pass may replace a scalar argument with a tensor argument,
                 # this pass may result in a non ATen compliant graph structure.
                 RemoveLocalScalarDenseOpsTransform(),
```

examples/models/llama/TARGETS

Lines changed: 0 additions & 1 deletion

```diff
@@ -148,7 +148,6 @@ runtime.python_library(
         ":source_transformation",
         "//ai_codesign/gen_ai/fast_hadamard_transform:fast_hadamard_transform",
         "//caffe2:torch",
-        "//executorch/backends/vulkan/_passes:vulkan_passes",
         "//executorch/exir/passes:init_mutable_pass",
         "//executorch/examples/models:model_base",
         "//executorch/examples/models:models",
```

examples/models/llama/export_llama_lib.py

Lines changed: 0 additions & 4 deletions

```diff
@@ -24,7 +24,6 @@
 import pkg_resources
 import torch

-from executorch.backends.vulkan._passes.remove_asserts import remove_asserts
 from executorch.devtools.backend_debug import print_delegation_info

 from executorch.devtools.etrecord import generate_etrecord as generate_etrecord_func
@@ -880,9 +879,6 @@ def _to_edge_and_lower_llama( # noqa: C901
         )
         modelname = f"vulkan_{modelname}"

-        # Need to remove asserts from the graph to prevent graph breaks
-        remove_asserts(builder_exported_to_edge.edge_manager.exported_program())
-
     if mps:
         partitioners.append(get_mps_partitioner(use_kv_cache))
         modelname = f"mps_{modelname}"
```
