Skip to content

Commit 50f7376

Browse files
authored
Arm backend: Support for dynamic shapes and fix resize bugs (#11310)
Dynamic shapes are detected when converting torch shapes to TOSA shapes, when they contain SymInts rather than regular ints. These are converted to -1s as this is how TOSA represents dynamic shapes. The resize op needs special handling for dynamic shapes, and will only work when all of the TOSA parameters (scale n/d, offset, border) work out to be constant values independent of the shape. Also fix bug where align_corners was always set to True, when it should be False Signed-off-by: Richard Burton <richard.burton@arm.com> Co-authored-by: Robert Hughes @Rob-Hughes-Arm
1 parent aed9c7e commit 50f7376

File tree

5 files changed

+295
-69
lines changed

5 files changed

+295
-69
lines changed

backends/arm/operators/op_upsample_bilinear2d.py

Lines changed: 20 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -49,15 +49,18 @@ def define_node(
4949
input_dtype = inputs[0].dtype
5050

5151
# tosa_shape output is NHWC, take HW
52-
input_size_yx = torch.tensor(
53-
tosa_shape(inputs[0].shape, inputs[0].dim_order)[1:3]
54-
)
55-
# Ignore scale and size parameters, directly use the output size as
56-
# we only support static shapes currently
57-
output_size_yx = torch.tensor(tosa_shape(output.shape, output.dim_order)[1:3])
52+
input_size_yx = tuple([inputs[0].shape[dim] for dim in inputs[0].dim_order])[
53+
1:3
54+
]
55+
output_size_yx = tuple([output.shape[dim] for dim in output.dim_order])[1:3]
5856

57+
# Get align_corners value from the node arguments.
58+
align_corners = bool(node.args[2])
5959
scale_n_yx, scale_d_yx, offset_yx, border_yx = get_resize_parameters(
60-
input_size_yx, output_size_yx, ResizeMode.NEAREST, align_corners=True
60+
input_size_yx,
61+
output_size_yx,
62+
ResizeMode.NEAREST,
63+
align_corners=align_corners,
6164
)
6265

6366
def in_int16_range(x):
@@ -139,15 +142,18 @@ def define_node(
139142
input_dtype = inputs[0].dtype
140143

141144
# tosa_shape output is NHWC, take HW
142-
input_size_yx = torch.tensor(
143-
tosa_shape(inputs[0].shape, inputs[0].dim_order)[1:3]
144-
)
145-
# Ignore scale and size parameters, directly use the output size as
146-
# we only support static shapes currently
147-
output_size_yx = torch.tensor(tosa_shape(output.shape, output.dim_order)[1:3])
145+
input_size_yx = tuple([inputs[0].shape[dim] for dim in inputs[0].dim_order])[
146+
1:3
147+
]
148+
output_size_yx = tuple([output.shape[dim] for dim in output.dim_order])[1:3]
148149

150+
# Get align_corners value from the node arguments.
151+
align_corners = bool(node.args[2])
149152
scale_n_yx, scale_d_yx, offset_yx, border_yx = get_resize_parameters(
150-
input_size_yx, output_size_yx, ResizeMode.NEAREST, align_corners=True
153+
input_size_yx,
154+
output_size_yx,
155+
ResizeMode.NEAREST,
156+
align_corners=align_corners,
151157
)
152158

153159
def in_int16_range(x):

backends/arm/operators/op_upsample_nearest2d.py

Lines changed: 15 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
validate_same_dtype,
1818
)
1919
from executorch.backends.arm.tosa_mapping import TosaArg
20-
from executorch.backends.arm.tosa_utils import get_resize_parameters, tosa_shape
20+
from executorch.backends.arm.tosa_utils import get_resize_parameters
2121

2222
from tosa_tools.v0_80.tosa.ResizeMode import ResizeMode # type: ignore
2323

@@ -43,19 +43,16 @@ def define_node(
4343
validate_num_inputs(self.target, inputs, 3)
4444
validate_same_dtype(self.target, [inputs[0], output])
4545

46-
if inputs[0].shape is None or output.shape is None:
47-
raise ValueError("Only static shapes are supported")
48-
4946
# tosa_shape output is NHWC, take HW
50-
input_size_yx = torch.tensor(
51-
tosa_shape(inputs[0].shape, inputs[0].dim_order)[1:3]
52-
)
53-
# Ignore scale and size parameters, directly use the output size as
54-
# we only support static shapes currently
55-
output_size_yx = torch.tensor(tosa_shape(output.shape, output.dim_order)[1:3])
47+
input_size_yx = tuple([inputs[0].shape[dim] for dim in inputs[0].dim_order])[
48+
1:3
49+
]
50+
output_size_yx = tuple([output.shape[dim] for dim in output.dim_order])[1:3]
5651

52+
# Align corners shouldn't make a difference for nearest upsampling. We set to False so
53+
# half pixel centers are used for resize parameter logic.
5754
scale_n_yx, scale_d_yx, offset_yx, border_yx = get_resize_parameters(
58-
input_size_yx, output_size_yx, ResizeMode.NEAREST, align_corners=True
55+
input_size_yx, output_size_yx, ResizeMode.NEAREST, align_corners=False
5956
)
6057

6158
def in_int16_range(x):
@@ -102,19 +99,16 @@ def define_node(
10299
validate_num_inputs(self.target, inputs, 3)
103100
validate_same_dtype(self.target, [inputs[0], output])
104101

105-
if inputs[0].shape is None or output.shape is None:
106-
raise ValueError("Only static shapes are supported")
107-
108102
# tosa_shape output is NHWC, take HW
109-
input_size_yx = torch.tensor(
110-
tosa_shape(inputs[0].shape, inputs[0].dim_order)[1:3]
111-
)
112-
# Ignore scale and size parameters, directly use the output size as
113-
# we only support static shapes currently
114-
output_size_yx = torch.tensor(tosa_shape(output.shape, output.dim_order)[1:3])
103+
input_size_yx = tuple([inputs[0].shape[dim] for dim in inputs[0].dim_order])[
104+
1:3
105+
]
106+
output_size_yx = tuple([output.shape[dim] for dim in output.dim_order])[1:3]
115107

108+
# Align corners shouldn't make a difference for nearest upsampling. We set to False so
109+
# half pixel centers are used for resize parameter logic.
116110
scale_n_yx, scale_d_yx, offset_yx, border_yx = get_resize_parameters(
117-
input_size_yx, output_size_yx, ResizeMode.NEAREST, align_corners=True
111+
input_size_yx, output_size_yx, ResizeMode.NEAREST, align_corners=False
118112
)
119113

120114
def in_int16_range(x):

backends/arm/test/ops/test_upsample_nearest2d.py

Lines changed: 167 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,17 @@
4040
"rand_one_and_half_size": lambda: (torch.rand(2, 4, 8, 3), (12, 4), None, False),
4141
}
4242

43+
test_data_suite_dynamic = {
44+
# (test_name, test_data, size, scale_factor, compare_outputs)
45+
"rand_double_scale": lambda: (torch.rand(2, 4, 8, 3), None, 2.0, False),
46+
"rand_double_scale_one_dim": lambda: (
47+
torch.rand(2, 4, 8, 3),
48+
None,
49+
(1.0, 2.0),
50+
False,
51+
),
52+
}
53+
4354

4455
class UpsamplingNearest2d(torch.nn.Module):
4556
def __init__(
@@ -161,3 +172,159 @@ def test_upsample_nearest2d_vec_tosa_BI_nearest(test_data: torch.Tensor):
161172
pipeline.pop_stage(-1)
162173

163174
pipeline.run()
175+
176+
177+
@common.parametrize("test_data", test_data_suite_dynamic)
178+
def test_upsample_nearest2d_dynamic_MI_nearest(test_data: torch.Tensor):
179+
test_data, size, scale_factor, compare_outputs = test_data()
180+
181+
batch_size = torch.export.Dim("batch", min=0, max=1000)
182+
input_height = torch.export.Dim("input_height", min=0, max=1000)
183+
input_width = torch.export.Dim("input_width", min=0, max=1000)
184+
185+
dynamic_shapes = {"x": {0: batch_size, 2: input_height, 3: input_width}}
186+
187+
pipeline = TosaPipelineMI[input_t1](
188+
UpsamplingNearest2d(size, scale_factor),
189+
(test_data,),
190+
aten_op,
191+
exir_op=[],
192+
dynamic_shapes=dynamic_shapes,
193+
)
194+
if not compare_outputs:
195+
pipeline.pop_stage(-1)
196+
pipeline.run()
197+
198+
199+
@common.parametrize("test_data", test_data_suite_dynamic)
200+
def test_upsample_nearest2d_dynamic_BI_nearest(test_data: torch.Tensor):
201+
test_data, size, scale_factor, compare_outputs = test_data()
202+
203+
batch_size = torch.export.Dim("batch", min=0, max=2)
204+
input_height = torch.export.Dim("input_height", min=0, max=8)
205+
input_width = torch.export.Dim("input_width", min=0, max=8)
206+
207+
dynamic_shapes = {"x": {0: batch_size, 2: input_height, 3: input_width}}
208+
209+
pipeline = TosaPipelineBI[input_t1](
210+
UpsamplingNearest2d(size, scale_factor),
211+
(test_data,),
212+
aten_op,
213+
exir_op=[],
214+
dynamic_shapes=dynamic_shapes,
215+
)
216+
if not compare_outputs:
217+
pipeline.pop_stage(-1)
218+
pipeline.run()
219+
220+
221+
@common.parametrize("test_data", test_data_suite_dynamic)
222+
def test_upsample_nearest2d_dynamic_MI_interpolate(test_data: torch.Tensor):
223+
test_data, size, scale_factor, compare_outputs = test_data()
224+
225+
batch_size = torch.export.Dim("batch", min=0, max=2)
226+
input_height = torch.export.Dim("input_height", min=4, max=8)
227+
input_width = torch.export.Dim("input_width", min=3, max=8)
228+
229+
dynamic_shapes = {
230+
"x": {
231+
0: batch_size,
232+
2: input_height,
233+
3: input_width,
234+
}
235+
}
236+
237+
pipeline = TosaPipelineMI[input_t1](
238+
Interpolate(size, scale_factor),
239+
(test_data,),
240+
aten_op,
241+
exir_op=[],
242+
dynamic_shapes=dynamic_shapes,
243+
)
244+
if not compare_outputs:
245+
pipeline.pop_stage(-1)
246+
pipeline.run()
247+
248+
249+
@common.parametrize("test_data", test_data_suite_dynamic)
250+
def test_upsample_nearest2d_dynamic_BI_interpolate(test_data: torch.Tensor):
251+
test_data, size, scale_factor, compare_outputs = test_data()
252+
253+
batch_size = torch.export.Dim("batch", min=0, max=2)
254+
input_height = torch.export.Dim("input_height", min=4, max=8)
255+
input_width = torch.export.Dim("input_width", min=3, max=8)
256+
257+
dynamic_shapes = {
258+
"x": {
259+
0: batch_size,
260+
2: input_height,
261+
3: input_width,
262+
}
263+
}
264+
265+
pipeline = TosaPipelineBI[input_t1](
266+
Interpolate(size, scale_factor),
267+
(test_data,),
268+
aten_op,
269+
exir_op=[],
270+
dynamic_shapes=dynamic_shapes,
271+
)
272+
if not compare_outputs:
273+
pipeline.pop_stage(-1)
274+
pipeline.run()
275+
276+
277+
@common.parametrize("test_data", test_data_suite_dynamic)
278+
def test_upsample_nearest2d_dynamic_MI_upsample(test_data: torch.Tensor):
279+
test_data, size, scale_factor, compare_outputs = test_data()
280+
281+
batch_size = torch.export.Dim("batch", min=0, max=1000)
282+
input_height = torch.export.Dim("input_height", min=0, max=1000)
283+
input_width = torch.export.Dim("input_width", min=0, max=1000)
284+
285+
dynamic_shapes = {
286+
"x": {
287+
0: batch_size,
288+
2: input_height,
289+
3: input_width,
290+
}
291+
}
292+
293+
pipeline = TosaPipelineMI[input_t1](
294+
Upsample(size, scale_factor),
295+
(test_data,),
296+
aten_op,
297+
exir_op=[],
298+
dynamic_shapes=dynamic_shapes,
299+
)
300+
if not compare_outputs:
301+
pipeline.pop_stage(-1)
302+
pipeline.run()
303+
304+
305+
@common.parametrize("test_data", test_data_suite_dynamic)
306+
def test_upsample_nearest2d_dynamic_BI_upsample(test_data: torch.Tensor):
307+
test_data, size, scale_factor, compare_outputs = test_data()
308+
309+
batch_size = torch.export.Dim("batch", min=0, max=2)
310+
input_height = torch.export.Dim("input_height", min=0, max=8)
311+
input_width = torch.export.Dim("input_width", min=0, max=8)
312+
313+
dynamic_shapes = {
314+
"x": {
315+
0: batch_size,
316+
2: input_height,
317+
3: input_width,
318+
}
319+
}
320+
321+
pipeline = TosaPipelineBI[input_t1](
322+
Upsample(size, scale_factor),
323+
(test_data,),
324+
aten_op,
325+
exir_op=[],
326+
dynamic_shapes=dynamic_shapes,
327+
)
328+
if not compare_outputs:
329+
pipeline.pop_stage(-1)
330+
pipeline.run()

backends/arm/test/tester/test_pipeline.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
# LICENSE file in the root directory of this source tree.
55

66
import logging
7-
from typing import Callable, Dict, Generic, List, Optional, Type, TypeVar
7+
from typing import Any, Callable, Dict, Generic, List, Optional, Tuple, Type, TypeVar
88

99
import torch
1010

@@ -88,10 +88,14 @@ def __init__(
8888
compile_spec: List[CompileSpec],
8989
exir_ops: Optional[str | List[str]] = None,
9090
use_to_edge_transform_and_lower: bool = True,
91+
dynamic_shapes: Optional[Tuple[Any]] = None,
9192
):
9293

9394
self.tester = ArmTester(
94-
module, example_inputs=test_data, compile_spec=compile_spec
95+
module,
96+
example_inputs=test_data,
97+
compile_spec=compile_spec,
98+
dynamic_shapes=dynamic_shapes,
9599
)
96100

97101
self.aten_ops = aten_ops if isinstance(aten_ops, list) else [aten_ops]
@@ -283,6 +287,7 @@ def __init__(
283287
atol: float = 1e-03,
284288
rtol: float = 1e-03,
285289
qtol: int = 1,
290+
dynamic_shapes: Optional[Tuple[Any]] = None,
286291
):
287292
tosa_profiles = {
288293
"0.80": TosaSpecification.create_from_string("TOSA-0.80+BI"),
@@ -310,6 +315,7 @@ def __init__(
310315
compile_spec,
311316
exir_op,
312317
use_to_edge_transform_and_lower,
318+
dynamic_shapes,
313319
)
314320
self.add_stage(self.tester.quantize, quant_stage, pos=0)
315321

@@ -381,6 +387,7 @@ def __init__(
381387
atol: float = 1e-03,
382388
rtol: float = 1e-03,
383389
qtol: int = 0,
390+
dynamic_shapes: Optional[Tuple[Any]] = None,
384391
):
385392
tosa_profiles = {
386393
"0.80": TosaSpecification.create_from_string("TOSA-0.80+MI"),
@@ -398,6 +405,7 @@ def __init__(
398405
compile_spec,
399406
exir_op,
400407
use_to_edge_transform_and_lower,
408+
dynamic_shapes,
401409
)
402410
self.add_stage_after(
403411
"export",

0 commit comments

Comments
 (0)