Skip to content

Commit deb63ec

Browse files
authored
Milestone3.1: Support tanh op in XNNPACK backend (#11364)
### Summary Support tanh in XNNPACK backend ### Test plan Wrote test cases to see if appropriate xnnpack tanh was called
1 parent 439aab7 commit deb63ec

File tree

10 files changed

+145
-0
lines changed

10 files changed

+145
-0
lines changed

backends/xnnpack/operators/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,5 +51,6 @@
5151
op_static_constant_pad,
5252
op_static_resize_bilinear_2d,
5353
op_sub,
54+
op_tanh,
5455
op_to_copy,
5556
)

backends/xnnpack/operators/op_tanh.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
from typing import Dict
8+
9+
import torch
10+
from executorch.backends.xnnpack.operators.node_visitor import (
11+
NodeVisitor,
12+
register_node_visitor,
13+
)
14+
from executorch.backends.xnnpack.serialization.xnnpack_graph_schema import (
15+
XNNGraph,
16+
XNNTanh,
17+
XNode,
18+
)
19+
from executorch.backends.xnnpack.utils.utils import get_input_node
20+
21+
22+
@register_node_visitor
23+
class TanhVisitor(NodeVisitor):
24+
target = "aten.tanh.default"
25+
26+
def __init__(self, *args) -> None:
27+
super().__init__(*args)
28+
29+
def define_node(
30+
self,
31+
node: torch.fx.Node,
32+
xnn_graph: XNNGraph,
33+
vals_to_ids: Dict[torch.fx.Node, int],
34+
debug_handle: int,
35+
) -> None:
36+
self.define_nodes_tensor_inputs_outputs(node, xnn_graph, vals_to_ids)
37+
38+
# input
39+
input_id = vals_to_ids[get_input_node(node, 0)]
40+
41+
# output
42+
output_id = vals_to_ids[node]
43+
44+
ser_node = XNode(
45+
xnode_union=XNNTanh(
46+
input_id=input_id,
47+
output_id=output_id,
48+
flags=0,
49+
),
50+
debug_handle=debug_handle,
51+
)
52+
xnn_graph.xnodes.append(ser_node)

backends/xnnpack/partition/config/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@
5050
SoftmaxConfig,
5151
SquareRootConfig,
5252
SubConfig,
53+
TanhConfig,
5354
UpsampleBilinear2dConfig,
5455
)
5556
from executorch.backends.xnnpack.partition.config.node_configs import (
@@ -101,6 +102,7 @@
101102
PreluConfig,
102103
ReciprocalSquareRootConfig,
103104
ReLUConfig,
105+
TanhConfig,
104106
# SDPAConfig, TODO: D60553559: preserving SDPA for fairseq fails
105107
SigmoidConfig,
106108
SliceCopyConfig,

backends/xnnpack/partition/config/generic_node_configs.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -378,6 +378,13 @@ def supported_precision_types(self) -> List[ConfigPrecisionType]:
378378
return [ConfigPrecisionType.FP32]
379379

380380

381+
class TanhConfig(GenericNodePartitionerConfig):
382+
target_name = "tanh.default"
383+
384+
def supported_precision_types(self) -> List[ConfigPrecisionType]:
385+
return [ConfigPrecisionType.FP32]
386+
387+
381388
class MeanDimConfig(GenericNodePartitionerConfig):
382389
target_name = "mean.dim"
383390

backends/xnnpack/partition/configs.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@
6767
exir_ops.edge.aten.rsqrt.default,
6868
exir_ops.edge.aten.log.default,
6969
exir_ops.edge.aten.gelu.default,
70+
exir_ops.edge.aten.tanh.default,
7071
]
7172

7273
SUPPORTED_MODULES = [

backends/xnnpack/runtime/XNNCompiler.cpp

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1513,6 +1513,36 @@ Error defineGeluNode(
15131513
return Error::Ok;
15141514
}
15151515

1516+
/*
1517+
Define serialized tanh node into the subgraph, using the remapped ids
1518+
to map the serialized ids, to the new ids generated when defining the
1519+
tensor value
1520+
*/
1521+
Error defineTanhNode(
1522+
xnn_subgraph_t subgraph_ptr,
1523+
const std::unordered_map<uint32_t, uint32_t>& remapped_ids,
1524+
const NodePtr node,
1525+
const fb_xnnpack::XNNGraph* graph) noexcept {
1526+
MAYBE_UNUSED(graph);
1527+
1528+
auto graph_node = node->xnode_union_as_XNNTanh();
1529+
1530+
xnn_status status = xnn_define_tanh(
1531+
subgraph_ptr,
1532+
remapped_ids.at(graph_node->input_id()),
1533+
remapped_ids.at(graph_node->output_id()),
1534+
graph_node->flags());
1535+
1536+
ET_CHECK_OR_RETURN_ERROR(
1537+
status == xnn_status_success,
1538+
Internal,
1539+
"Failed to create tanh node %i with code: %s",
1540+
node->debug_handle(),
1541+
xnn_status_to_string(status));
1542+
1543+
return Error::Ok;
1544+
}
1545+
15161546
/*
15171547
Define serialized ceiling node into the subgraph, using the remapped ids
15181548
to map the serialized ids, to the new ids generated when defining the
@@ -2108,6 +2138,7 @@ DefineNodeFunc getDefineNodeFunc(fb_xnnpack::XNodeUnion nodeType) {
21082138
_DEFINE(Hardswish)
21092139
_DEFINE(LeakyReLU)
21102140
_DEFINE(Log)
2141+
_DEFINE(Tanh)
21112142
_DEFINE(Maximum)
21122143
_DEFINE(Negate)
21132144
_DEFINE(Square)

backends/xnnpack/serialization/runtime_schema.fbs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,7 @@ union XNodeUnion {
146146
XNNReciprocalSquareRoot: _XNNNode1x1,
147147
XNNLog: _XNNNode1x1,
148148
XNNGelu: _XNNNode1x1,
149+
XNNTanh: _XNNNode1x1,
149150
}
150151

151152
union XValueUnion {

backends/xnnpack/serialization/schema.fbs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,7 @@ union XNodeUnion {
142142
XNNReciprocalSquareRoot: _XNNNode1x1,
143143
XNNLog: _XNNNode1x1,
144144
XNNGelu: _XNNNode1x1,
145+
XNNTanh: _XNNNode1x1,
145146
}
146147

147148
union XValueUnion {

backends/xnnpack/serialization/xnnpack_graph_schema.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -324,6 +324,11 @@ class XNNLog(XNNNode1x1):
324324
pass
325325

326326

327+
@dataclass
328+
class XNNTanh(XNNNode1x1):
329+
pass
330+
331+
327332
@dataclass
328333
class XNNMaximum(XNNNode2x1):
329334
pass
@@ -396,6 +401,7 @@ class XNNScaledDotProductAttention:
396401
XNNReciprocalSquareRoot,
397402
XNNLog,
398403
XNNGelu,
404+
XNNTanh,
399405
]
400406

401407

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import unittest
8+
9+
import torch
10+
from executorch.backends.xnnpack.test.tester import Tester
11+
12+
13+
class TestTanh(unittest.TestCase):
14+
def setUp(self):
15+
torch._dynamo.reset()
16+
17+
class Tanh(torch.nn.Module):
18+
def __init__(self):
19+
super().__init__()
20+
21+
def forward(self, x):
22+
return torch.tanh(x)
23+
24+
def run_tanh_test(self, inputs):
25+
(
26+
Tester(self.Tanh(), inputs)
27+
.export()
28+
.check_count({"torch.ops.aten.tanh.default": 1})
29+
.to_edge_transform_and_lower()
30+
.check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
31+
.check_not(["executorch_exir_dialects_edge__ops_aten_tanh_default"])
32+
.to_executorch()
33+
.serialize()
34+
.run_method_and_compare_outputs()
35+
)
36+
37+
def test_fp16_tanh(self):
38+
inputs = (torch.randn(20).to(torch.float16),)
39+
self.run_tanh_test(inputs)
40+
41+
def test_fp32_tanh(self):
42+
inputs = (torch.randn(20),)
43+
self.run_tanh_test(inputs)

0 commit comments

Comments
 (0)