
Commit ca68afc

Qualcomm AI Engine Direct - GA MobileVit V2 (#11279)
### Summary
- GA Model: MobileVit V2 (**PLEASE use QNN 2.29** when running this model, otherwise the accuracy is bad.)
- Support the Fold/Unfold operators (ATen col2im/im2col) with decomposition passes

### Test plan
E2E UT

#### Accuracy
top1: ~50%
top5: ~85%

#### Speed
SM8750: 2.34 ms/inf

### Script
python examples/qualcomm/oss_scripts/mobilevit_v2.py -b build-android -H $HOST -s $DEVICE --dataset ../imagenet-mini/val/
1 parent e2ce885 commit ca68afc
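
For context on the Fold/Unfold support: PyTorch's `nn.Unfold`/`nn.Fold` modules lower to the ATen `im2col`/`col2im` ops when a model is exported, which is what the new `DecomposeColIm` pass targets. A minimal sketch of how to confirm this (the toy module is illustrative and not part of this commit; the exact printed op name can vary across torch versions):

```python
import torch

class TinyUnfold(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # kernel_size == stride is the non-overlapping case handled by DecomposeColIm
        self.unfold = torch.nn.Unfold(kernel_size=2, stride=2)

    def forward(self, x):
        return self.unfold(x)

ep = torch.export.export(TinyUnfold(), (torch.randn(1, 3, 8, 8),))
# Inspect the captured graph; an aten im2col (unfold) call should appear here.
print(ep.graph_module.graph)
```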

File tree: 13 files changed, +485 −27 lines changed


backends/qualcomm/_passes/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -13,6 +13,7 @@
 from .convert_square_to_pow import ConvertSquareToPow
 from .decompose_any import DecomposeAny
 from .decompose_cdist import DecomposeCDist
+from .decompose_col_im import DecomposeColIm
 from .decompose_einsum import DecomposeEinsum
 from .decompose_expm1 import DecomposeExpM1
 from .decompose_linalg_vector_norm import DecomposeLinalgVectorNorm
@@ -49,6 +50,7 @@
     ConvertSquareToPow,
     DecomposeAny,
     DecomposeCDist,
+    DecomposeColIm,
     DecomposeEinsum,
     DecomposeExpM1,
     DecomposeLinalgVectorNorm,
backends/qualcomm/_passes/decompose_col_im.py (new file)

Lines changed: 121 additions & 0 deletions

# Copyright (c) Qualcomm Innovation Center, Inc.
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import torch
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass, PassResult

from .utils import copy_meta


class DecomposeColIm(ExportPass):
    """
    Decompose im2col(unfold) to pixel_unshuffle + view_copy
    Decompose col2im(fold) to view_copy + pixel_shuffle
    """

    def __init__(self):
        super(DecomposeColIm, self).__init__()
        self.im2col_op = exir_ops.edge.aten.im2col.default
        self.col2im_op = exir_ops.edge.aten.col2im.default
        self.pixel_unshuffle_op = exir_ops.edge.aten.pixel_unshuffle.default
        self.pixel_shuffle_op = exir_ops.edge.aten.pixel_shuffle.default
        self.view_copy_op = exir_ops.edge.aten.view_copy.default

    def _decompose_im2col(self, graph_module: torch.fx.GraphModule):
        for node in graph_module.graph.nodes:
            if node.target == self.im2col_op:
                input_node = node.args[0]
                kernel_size = node.args[1]
                stride = node.args[4]
                batch_size = node.meta["val"].shape[0]
                assert (
                    stride == kernel_size
                ), "im2col can only be converted when stride == kernel_size"
                assert (
                    input_node.meta["val"].dim() == 4
                ), "im2col can only be converted when input dims == 4"
                assert (
                    kernel_size[0] == kernel_size[1]
                ), "im2col can only be converted when kernel height == width"
                users = list(node.users.keys())
                with graph_module.graph.inserting_after(input_node):
                    pixel_unshuffle_node = graph_module.graph.create_node(
                        "call_function",
                        self.pixel_unshuffle_op,
                        (input_node, kernel_size[0]),
                    )
                    pixel_unshuffle_node.meta = copy_meta(node.meta)
                    orig_height = input_node.meta["val"].shape[2]
                    orig_width = input_node.meta["val"].shape[3]

                    pixel_unshuffle_node.meta["val"] = pixel_unshuffle_node.meta[
                        "val"
                    ].reshape(
                        batch_size,
                        -1,
                        orig_height // kernel_size[0],
                        orig_width // kernel_size[1],
                    )

                with graph_module.graph.inserting_after(pixel_unshuffle_node):
                    view_copy_node = graph_module.graph.create_node(
                        "call_function",
                        self.view_copy_op,
                        (pixel_unshuffle_node, tuple(node.meta["val"].shape)),
                    )
                    view_copy_node.meta = copy_meta(node.meta)
                    for user in users:
                        user.replace_input_with(node, view_copy_node)

    def _decompose_col2im(self, graph_module: torch.fx.GraphModule):
        for node in graph_module.graph.nodes:
            if node.target == self.col2im_op:
                input_node = node.args[0]
                output_size = node.args[1]
                kernel_size = node.args[2]
                stride = node.args[5]
                batch_size = node.meta["val"].shape[0]
                assert (
                    stride == kernel_size
                ), "col2im can only be converted when stride == kernel_size"
                assert (
                    node.meta["val"].dim() == 4
                ), "col2im can only be converted when output dims == 4"
                assert (
                    kernel_size[0] == kernel_size[1]
                ), "col2im can only be converted when kernel height == width"
                users = list(node.users.keys())
                with graph_module.graph.inserting_after(input_node):
                    view_tensor = input_node.meta["val"].reshape(
                        batch_size,
                        -1,
                        output_size[0] // kernel_size[0],
                        output_size[1] // kernel_size[1],
                    )
                    view_copy_node = graph_module.graph.create_node(
                        "call_function",
                        self.view_copy_op,
                        (input_node, tuple(view_tensor.shape)),
                    )
                    view_copy_node.meta = copy_meta(node.meta)
                    view_copy_node.meta["val"] = view_tensor

                with graph_module.graph.inserting_after(view_copy_node):
                    pixel_shuffle_node = graph_module.graph.create_node(
                        "call_function",
                        self.pixel_shuffle_op,
                        (view_copy_node, kernel_size[0]),
                    )
                    pixel_shuffle_node.meta = copy_meta(node.meta)

                    for user in users:
                        user.replace_input_with(node, pixel_shuffle_node)

    def call(self, graph_module: torch.fx.GraphModule):
        self._decompose_im2col(graph_module)
        self._decompose_col2im(graph_module)
        graph_module.recompile()
        return PassResult(graph_module, True)
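
The decomposition relies on the fact that, for non-overlapping patches (stride == kernel_size), `im2col` is exactly `pixel_unshuffle` followed by a reshape, and `col2im` is the reverse. A small numeric sanity check, independent of this commit (shapes are illustrative):

```python
import torch
import torch.nn.functional as F

x = torch.randn(2, 3, 8, 8)
k = 2  # kernel_size == stride: the non-overlapping case DecomposeColIm asserts on

# im2col (unfold) vs. pixel_unshuffle + reshape
unfolded = F.unfold(x, kernel_size=k, stride=k)                  # (N, C*k*k, L)
via_unshuffle = F.pixel_unshuffle(x, k).reshape(unfolded.shape)  # (N, C*k*k, H/k, W/k) -> (N, C*k*k, L)
print(torch.allclose(unfolded, via_unshuffle))                   # True

# col2im (fold) vs. reshape + pixel_shuffle
folded = F.fold(unfolded, output_size=(8, 8), kernel_size=k, stride=k)
via_shuffle = F.pixel_shuffle(unfolded.reshape(2, -1, 8 // k, 8 // k), k)
print(torch.allclose(folded, via_shuffle))                       # True
```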

backends/qualcomm/_passes/layout_transform.py

Lines changed: 33 additions & 4 deletions
@@ -23,8 +23,21 @@
 class LayoutTransform(ExportPass):
     """
     QNN delegate requires channel last layout format, this pass aims to
-    help generate the correct transformation by inserting fewest ammount of
+    help generate the correct transformation by inserting the fewest number of
     'permute' operators in the graph.
+    Please note that the permute ops are inserted during qnn_preprocess.
+
+    Operations are divided into 3 categories: sensitive_layout, agnostic_layout, and pytorch_layout.
+    sensitive_layout: These ops must be lowered to QNN in NHWC format. A permute(NCHW->NHWC) op will be inserted in front of a sensitive_layout op.
+    agnostic_layout: These ops are agnostic to the layout format, meaning they can be passed to QNN in either NCHW or NHWC format.
+    pytorch_layout: These ops must be lowered to QNN in NCHW format. A permute(NHWC->NCHW) op will be inserted in front of a pytorch_layout op.
+
+    For optimization purposes, a permute is only inserted when it is necessary to switch between sensitive_layout and pytorch_layout.
+    For example, consider a model with three kinds of operations: conv(sensitive_layout), relu(agnostic_layout), and unsqueeze(pytorch_layout).
+    If the graph originally looks like: in -> conv -> relu -> conv -> relu -> unsqueeze -> out
+    After the layout_transform pass: in -> permute(NCHW->NHWC) -> conv -> relu -> conv -> relu -> permute(NHWC->NCHW) -> unsqueeze -> out
+    The 1st permute is inserted because conv is layout sensitive. Since relu is layout agnostic, it does not matter which format it receives.
+    This format works fine until unsqueeze is encountered, which is a pytorch_layout operation, so a 2nd permute is necessary to convert back to pytorch format.
     """

     layout_sensitive_ops = {
@@ -76,7 +89,6 @@ class LayoutTransform(ExportPass):
         exir_ops.edge.aten.logical_not.default,
         exir_ops.edge.aten.lt.Scalar,
         exir_ops.edge.aten.lt.Tensor,
-        exir_ops.edge.aten._log_softmax.default,
         exir_ops.edge.aten.maximum.default,
         exir_ops.edge.aten.mean.dim,
         exir_ops.edge.aten.minimum.default,
@@ -88,7 +100,6 @@ class LayoutTransform(ExportPass):
         exir_ops.edge.aten.prelu.default,
         exir_ops.edge.aten.repeat.default,
         exir_ops.edge.aten.relu.default,
-        exir_ops.edge.aten._softmax.default,  # TODO: Need to find a new solution to do "axis_order" to transform axis.
         exir_ops.edge.aten.sigmoid.default,
         exir_ops.edge.aten.split_with_sizes.default,
         exir_ops.edge.aten.split_with_sizes_copy.default,
@@ -282,11 +293,29 @@ def check_arg(arg):
         else:
             check_arg(args)

+    def conditional_sensitive_check(self, node):
+        # For softmax and log_softmax, we must ensure axis == -1, since that is the only axis supported by QNN.
+        # Softmax and log_softmax are treated as pytorch_layout by default, and as sensitive_layout when the axis is not the last dim.
+        target_nodes = [
+            exir_ops.edge.aten._softmax.default,
+            exir_ops.edge.aten._log_softmax.default,
+        ]
+        if node.target in target_nodes:
+            dim = node.args[1]
+            if dim < 0:
+                dim = dim % node.meta["val"].dim()
+            if dim != node.meta["val"].dim() - 1:
+                return True
+        return False
+
     def call(self, graph_module: torch.fx.GraphModule):
         graph = graph_module.graph
         sensitive_nodes = [
-            node for node in graph.nodes if self.is_layout_sensitive(node)
+            node
+            for node in graph.nodes
+            if self.is_layout_sensitive(node) or self.conditional_sensitive_check(node)
         ]
+
         # perform first run traversal for identifying nodes subjected to layout changes
         if self.insert_permute:
             self.insert_permute, self.transformed_tag = False, QCOM_LAYOUT_CHANGE
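
The new `conditional_sensitive_check` only pushes softmax/log_softmax into the sensitive_layout bucket when the reduction axis is not the last dimension. A standalone sketch of that axis test on plain integers (illustrative only; the real pass reads the dim from `node.args[1]` and the rank from `node.meta["val"]`):

```python
# A minimal sketch of the same axis check, using plain ints instead of FX nodes.
def needs_layout_sensitive_handling(ndim: int, softmax_dim: int) -> bool:
    if softmax_dim < 0:
        softmax_dim = softmax_dim % ndim  # normalize negative dims, as the pass does
    return softmax_dim != ndim - 1

print(needs_layout_sensitive_handling(4, -1))  # False: last axis, stays pytorch_layout
print(needs_layout_sensitive_handling(4, 1))   # True: channel axis, treated as sensitive_layout
```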

backends/qualcomm/_passes/qnn_pass_manager.py

Lines changed: 2 additions & 0 deletions
@@ -18,6 +18,7 @@
     ConvertSquareToPow,
     DecomposeAny,
     DecomposeCDist,
+    DecomposeColIm,
     DecomposeEinsum,
     DecomposeExpM1,
     DecomposeLinalgVectorNorm,
@@ -82,6 +83,7 @@ def get_capture_program_passes():
         (ConvertBmmToMatmul, True),
         (ConvertConv1dToConv2d, True),
         (DecomposeAny, True),
+        (DecomposeColIm, True),
         (ExpandBroadcastTensorShape, False),
         (FixedLinearKeepDim, True),
         (FoldQDQ, True),

backends/qualcomm/_passes/utils.py

Lines changed: 2 additions & 0 deletions
@@ -67,6 +67,7 @@ def get_passes_dependency_for_capture_program():
         ConvertBmmToMatmul,
         ConvertConv1dToConv2d,
         DecomposeAny,
+        DecomposeColIm,
         DecomposeLinalgVectorNorm,
         ExpandBroadcastTensorShape,
         FixedLinearKeepDim,
@@ -91,6 +92,7 @@ def get_passes_dependency_for_capture_program():
         AnnotateUnbind: [RemoveRedundancy],
         ConvertBmmToMatmul: [RecomposePixelUnshuffle],
         DecomposeAny: [RemoveRedundancy],
+        DecomposeColIm: [FoldQDQ],
         DecomposeLinalgVectorNorm: [RemoveRedundancy],
         ExpandBroadcastTensorShape: [FoldQDQ],
         FixedLinearKeepDim: [FoldQDQ],
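
The dependency table above presumably lets the pass manager derive an ordering for the capture-program passes; whether a key runs before or after the passes it maps to is defined by the pass manager and is not visible in this diff. As a generic, hypothetical illustration of resolving such a table into an order (plain strings stand in for the pass classes):

```python
from typing import Dict, List, TypeVar

T = TypeVar("T")

def topological_order(deps: Dict[T, List[T]]) -> List[T]:
    """Return an ordering where every entry appears after the items it depends on."""
    visited: Dict[T, int] = {}  # 0 = in progress, 1 = done
    order: List[T] = []

    def visit(node: T) -> None:
        state = visited.get(node)
        if state == 1:
            return
        if state == 0:
            raise ValueError(f"dependency cycle involving {node!r}")
        visited[node] = 0
        for dep in deps.get(node, []):
            visit(dep)
        visited[node] = 1
        order.append(node)

    for node in deps:
        visit(node)
    return order

# Hypothetical names standing in for the pass classes above.
print(topological_order({"DecomposeColIm": ["FoldQDQ"], "FoldQDQ": []}))
# ['FoldQDQ', 'DecomposeColIm']
```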

backends/qualcomm/builders/op_index_put.py

Lines changed: 1 addition & 1 deletion
@@ -41,7 +41,7 @@ def define_node(
         indices_qnn = torch.cat(indices_unpacked).unsqueeze(0)
         indice_node = [n for n in indicies_node if isinstance(n, torch.fx.Node)]
         # TODO consider to write a pass to combine to one input tensor for indices
-        assert len(indice_node) == 1, "Not support mutilple indices tensor"
+        assert len(indice_node) == 1, "Not support multiple indices tensor"

         indices_tensor_wrapper = self.define_tensor(
             indice_node[0],

backends/qualcomm/builders/op_log_softmax.py

Lines changed: 5 additions & 1 deletion
@@ -3,10 +3,10 @@
 #
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
+import warnings
 from typing import cast, Dict

 import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper
-
 import numpy as np
 import torch
 from executorch.backends.qualcomm.utils.constants import QCOM_AXIS_ORDER, QCOM_DATA
@@ -58,6 +58,10 @@ def define_node(

         # logsoftmax only supports last dimension for now, which is channel in QNN
         if dim != input_tensor.dim() - 1:
+            warnings.warn(
+                "[QNN Delegate Op Builder]: LogSoftmax only supports channel axis.",
+                stacklevel=1,
+            )
             return None

         log_softmax_op = PyQnnWrapper.PyQnnOpWrapper(

backends/qualcomm/builders/op_softmax.py

Lines changed: 5 additions & 2 deletions
@@ -3,10 +3,10 @@
 #
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
+import warnings
 from typing import cast, Dict

 import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper
-
 import numpy as np
 import torch
 from executorch.backends.qualcomm.utils.constants import QCOM_AXIS_ORDER, QCOM_DATA
@@ -53,9 +53,12 @@ def define_node(
         dim = dim % len(input_tensor.shape)
         if QCOM_AXIS_ORDER in node.meta:
             dim = node.meta[QCOM_AXIS_ORDER].index(dim)
-
         # softmax only supports last dimension for now, which is channel in QNN
         if dim != input_tensor.dim() - 1:
+            warnings.warn(
+                "[QNN Delegate Op Builder]: Softmax only supports channel axis.",
+                stacklevel=1,
+            )
             return None

         softmax_op = PyQnnWrapper.PyQnnOpWrapper(
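
Both builders first remap the softmax axis through `QCOM_AXIS_ORDER` when the surrounding graph has been converted to NHWC. A worked example of that remapping (the `(0, 2, 3, 1)` order is the usual NCHW->NHWC permutation and is assumed here rather than taken from this diff):

```python
# Assumed NCHW -> NHWC axis order; position i holds the original dim that now sits at index i.
axis_order = (0, 2, 3, 1)

dim = 1                           # softmax over channels in the original NCHW graph
remapped = axis_order.index(dim)  # where that dim lives after the layout transform
print(remapped)                   # 3 -> the last axis in NHWC, which QNN's Softmax supports
```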

backends/qualcomm/partition/utils.py

Lines changed: 2 additions & 0 deletions
@@ -42,9 +42,11 @@ def filter_fn(node: torch.fx.Node) -> bool:
 def get_skip_decomp_table() -> List[torch._ops.OperatorBase]:
     do_not_decompose = [
         torch.ops.aten.adaptive_avg_pool2d.default,
+        torch.ops.aten.col2im.default,
         torch.ops.aten.elu.default,
         torch.ops.aten.hardsigmoid.default,
         torch.ops.aten.hardswish.default,
+        torch.ops.aten.im2col.default,
         torch.ops.aten.instance_norm.default,
         torch.ops.aten.leaky_relu.default,
         torch.ops.aten.linear.default,

backends/qualcomm/quantizer/annotators.py

Lines changed: 12 additions & 1 deletion
@@ -399,6 +399,11 @@ def annotate_cos(node: Node, quantization_config: QuantizationConfig) -> None:
     annotate_single_in_single_out(node, quantization_config)


+@register_annotator([torch.ops.aten.col2im.default, torch.ops.aten.im2col.default])
+def annotate_col_im(node: Node, quantization_config: QuantizationConfig) -> None:
+    annotate_single_in_single_out(node, quantization_config)
+
+
 @register_annotator([torch.ops.aten.sin.default])
 def annotate_sin(node: Node, quantization_config: QuantizationConfig) -> None:
     annotate_single_in_single_out(node, quantization_config)
@@ -508,7 +513,13 @@ def annotate_prelu(node: Node, quantization_config: QuantizationConfig) -> None:
     annotate_binary(node, quantization_config)


-@register_annotator([torch.ops.aten.view.default, torch.ops.aten._unsafe_view.default])
+@register_annotator(
+    [
+        torch.ops.aten.view_copy.default,
+        torch.ops.aten.view.default,
+        torch.ops.aten._unsafe_view.default,
+    ]
+)
 def annotate_view(node: Node, quantization_config: QuantizationConfig) -> None:
     annotate_in_out_obs_sharing_op(node, quantization_config)
     if not _is_annotated([node]):
