Skip to content

Commit e42dafc

Browse files
authored
Qualcomm AI Engine Direct - GA CvT (#11036)
Summary: - Add CvT example script - Add the test for CvT - Fix missing quant config for torch.ops.aten.split_with_sizes.default
1 parent bd57234 commit e42dafc

File tree

2 files changed

+244
-0
lines changed

2 files changed

+244
-0
lines changed

backends/qualcomm/tests/test_qnn_delegate.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3902,6 +3902,42 @@ def test_conv_former(self):
39023902
self.assertGreaterEqual(msg["top_1"], 60)
39033903
self.assertGreaterEqual(msg["top_5"], 80)
39043904

3905+
def test_cvt(self):
3906+
if not self.required_envs([self.image_dataset]):
3907+
self.skipTest("missing required envs")
3908+
3909+
cmds = [
3910+
"python",
3911+
f"{self.executorch_root}/examples/qualcomm/oss_scripts/cvt.py",
3912+
"--dataset",
3913+
self.image_dataset,
3914+
"--artifact",
3915+
self.artifact_dir,
3916+
"--build_folder",
3917+
self.build_folder,
3918+
"--device",
3919+
self.device,
3920+
"--model",
3921+
self.model,
3922+
"--ip",
3923+
self.ip,
3924+
"--port",
3925+
str(self.port),
3926+
]
3927+
if self.host:
3928+
cmds.extend(["--host", self.host])
3929+
3930+
p = subprocess.Popen(cmds, stdout=subprocess.DEVNULL)
3931+
with Listener((self.ip, self.port)) as listener:
3932+
conn = listener.accept()
3933+
p.communicate()
3934+
msg = json.loads(conn.recv())
3935+
if "Error" in msg:
3936+
self.fail(msg["Error"])
3937+
else:
3938+
self.assertGreaterEqual(msg["top_1"], 70)
3939+
self.assertGreaterEqual(msg["top_5"], 90)
3940+
39053941
def test_deit(self):
39063942
if not self.required_envs([self.image_dataset]):
39073943
self.skipTest("missing required envs")

examples/qualcomm/oss_scripts/cvt.py

Lines changed: 208 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,208 @@
1+
# Copyright (c) Qualcomm Innovation Center, Inc.
2+
# All rights reserved
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import json
8+
import logging
9+
import os
10+
import types
11+
from multiprocessing.connection import Client
12+
13+
import numpy as np
14+
15+
import torch
16+
from executorch.backends.qualcomm.quantizer.quantizer import QuantDtype
17+
from executorch.examples.qualcomm.utils import (
18+
build_executorch_binary,
19+
get_imagenet_dataset,
20+
make_output_dir,
21+
parse_skip_delegation_node,
22+
setup_common_args_and_variables,
23+
SimpleADB,
24+
topk_accuracy,
25+
)
26+
from transformers import AutoModelForImageClassification
27+
from transformers.models.cvt.modeling_cvt import CvtSelfAttention
28+
29+
30+
# Copy from transformers/models/cvt/modeling_cvt.py in transformers 4.47.1
31+
# torch.einsum("bhlk,bhtk->bhlt", [query, key]) will result in prepare failed due to 5D tensor with decompose_einsum.
32+
# TODO: once HTP fixed, this workaround can be removed
33+
def attention_forward_without_einsum(self, hidden_state, height, width):
34+
if self.with_cls_token:
35+
cls_token, hidden_state = torch.split(hidden_state, [1, height * width], 1)
36+
batch_size, hidden_size, num_channels = hidden_state.shape
37+
# rearrange "b (h w) c -> b c h w"
38+
hidden_state = hidden_state.permute(0, 2, 1).view(
39+
batch_size, num_channels, height, width
40+
)
41+
42+
key = self.convolution_projection_key(hidden_state)
43+
query = self.convolution_projection_query(hidden_state)
44+
value = self.convolution_projection_value(hidden_state)
45+
46+
if self.with_cls_token:
47+
query = torch.cat((cls_token, query), dim=1)
48+
key = torch.cat((cls_token, key), dim=1)
49+
value = torch.cat((cls_token, value), dim=1)
50+
51+
head_dim = self.embed_dim // self.num_heads
52+
53+
query = self.rearrange_for_multi_head_attention(self.projection_query(query))
54+
key = self.rearrange_for_multi_head_attention(self.projection_key(key))
55+
value = self.rearrange_for_multi_head_attention(self.projection_value(value))
56+
# ====================Qualcomm Changed=================================
57+
attention_score = query @ key.transpose(-1, -2)
58+
attention_score = attention_score * self.scale
59+
# attention_score = torch.einsum("bhlk,bhtk->bhlt", [query, key]) * self.scale
60+
# =====================================================================
61+
attention_probs = torch.nn.functional.softmax(attention_score, dim=-1)
62+
attention_probs = self.dropout(attention_probs)
63+
# ====================Qualcomm Changed=================================
64+
context = attention_probs @ value
65+
# context = torch.einsum("bhlt,bhtv->bhlv", [attention_probs, value])
66+
# =====================================================================
67+
# rearrange"b h t d -> b t (h d)"
68+
_, _, hidden_size, _ = context.shape
69+
context = (
70+
context.permute(0, 2, 1, 3)
71+
.contiguous()
72+
.view(batch_size, hidden_size, self.num_heads * head_dim)
73+
)
74+
return context
75+
76+
77+
def _replace_attention(
78+
module: torch.nn.Module,
79+
):
80+
for _, child in module.named_children():
81+
if isinstance(child, CvtSelfAttention):
82+
child.forward = types.MethodType( # pyre-ignore
83+
attention_forward_without_einsum, child
84+
)
85+
else:
86+
_replace_attention(child)
87+
return module
88+
89+
90+
def main(args):
91+
skip_node_id_set, skip_node_op_set = parse_skip_delegation_node(args)
92+
93+
# ensure the working directory exist.
94+
os.makedirs(args.artifact, exist_ok=True)
95+
96+
if not args.compile_only and args.device is None:
97+
raise RuntimeError(
98+
"device serial is required if not compile only. "
99+
"Please specify a device serial by -s/--device argument."
100+
)
101+
102+
data_num = 100
103+
if args.ci:
104+
inputs = [(torch.rand(1, 3, 224, 224),)]
105+
logging.warning(
106+
"This option is for CI to verify the export flow. It uses random input and will result in poor accuracy."
107+
)
108+
else:
109+
inputs, targets, input_list = get_imagenet_dataset(
110+
dataset_path=f"{args.dataset}",
111+
data_size=data_num,
112+
image_shape=(256, 256),
113+
crop_size=224,
114+
)
115+
116+
module = (
117+
AutoModelForImageClassification.from_pretrained("microsoft/cvt-13")
118+
.eval()
119+
.to("cpu")
120+
)
121+
# Fix prepare failed due to einsum
122+
module = _replace_attention(module)
123+
pte_filename = "cvt_qnn_q8"
124+
build_executorch_binary(
125+
module.eval(),
126+
inputs[0],
127+
args.model,
128+
f"{args.artifact}/{pte_filename}",
129+
inputs,
130+
skip_node_id_set=skip_node_id_set,
131+
skip_node_op_set=skip_node_op_set,
132+
quant_dtype=QuantDtype.use_8a8w,
133+
shared_buffer=args.shared_buffer,
134+
)
135+
136+
if args.compile_only:
137+
return
138+
139+
adb = SimpleADB(
140+
qnn_sdk=os.getenv("QNN_SDK_ROOT"),
141+
build_path=f"{args.build_folder}",
142+
pte_path=f"{args.artifact}/{pte_filename}.pte",
143+
workspace=f"/data/local/tmp/executorch/{pte_filename}",
144+
device_id=args.device,
145+
host_id=args.host,
146+
soc_model=args.model,
147+
shared_buffer=args.shared_buffer,
148+
)
149+
adb.push(inputs=inputs, input_list=input_list)
150+
adb.execute()
151+
152+
# collect output data
153+
output_data_folder = f"{args.artifact}/outputs"
154+
make_output_dir(output_data_folder)
155+
156+
adb.pull(output_path=args.artifact)
157+
158+
# top-k analysis
159+
predictions = []
160+
for i in range(data_num):
161+
predictions.append(
162+
np.fromfile(
163+
os.path.join(output_data_folder, f"output_{i}_0.raw"), dtype=np.float32
164+
)
165+
)
166+
167+
k_val = [1, 5]
168+
topk = [topk_accuracy(predictions, targets, k).item() for k in k_val]
169+
if args.ip and args.port != -1:
170+
with Client((args.ip, args.port)) as conn:
171+
conn.send(json.dumps({f"top_{k}": topk[i] for i, k in enumerate(k_val)}))
172+
else:
173+
for i, k in enumerate(k_val):
174+
print(f"top_{k}->{topk[i]}%")
175+
176+
177+
if __name__ == "__main__":
178+
parser = setup_common_args_and_variables()
179+
180+
parser.add_argument(
181+
"-d",
182+
"--dataset",
183+
help=(
184+
"path to the validation folder of ImageNet dataset. "
185+
"e.g. --dataset imagenet-mini/val "
186+
"for https://www.kaggle.com/datasets/ifigotin/imagenetmini-1000)"
187+
),
188+
type=str,
189+
required=False,
190+
)
191+
192+
parser.add_argument(
193+
"-a",
194+
"--artifact",
195+
help="path for storing generated artifacts by this example. " "Default ./cvt",
196+
default="./cvt",
197+
type=str,
198+
)
199+
200+
args = parser.parse_args()
201+
try:
202+
main(args)
203+
except Exception as e:
204+
if args.ip and args.port != -1:
205+
with Client((args.ip, args.port)) as conn:
206+
conn.send(json.dumps({"Error": str(e)}))
207+
else:
208+
raise Exception(e)

0 commit comments

Comments
 (0)