10 files changed, +145 -0 lines changed.
backends/xnnpack/operators/__init__.py:

@@ -51,5 +51,6 @@
     op_static_constant_pad,
     op_static_resize_bilinear_2d,
     op_sub,
+    op_tanh,
     op_to_copy,
 )
backends/xnnpack/operators/op_tanh.py (new file):

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Dict

import torch
from executorch.backends.xnnpack.operators.node_visitor import (
    NodeVisitor,
    register_node_visitor,
)
from executorch.backends.xnnpack.serialization.xnnpack_graph_schema import (
    XNNGraph,
    XNNTanh,
    XNode,
)
from executorch.backends.xnnpack.utils.utils import get_input_node


@register_node_visitor
class TanhVisitor(NodeVisitor):
    target = "aten.tanh.default"

    def __init__(self, *args) -> None:
        super().__init__(*args)

    def define_node(
        self,
        node: torch.fx.Node,
        xnn_graph: XNNGraph,
        vals_to_ids: Dict[torch.fx.Node, int],
        debug_handle: int,
    ) -> None:
        self.define_nodes_tensor_inputs_outputs(node, xnn_graph, vals_to_ids)

        # input
        input_id = vals_to_ids[get_input_node(node, 0)]

        # output
        output_id = vals_to_ids[node]

        ser_node = XNode(
            xnode_union=XNNTanh(
                input_id=input_id,
                output_id=output_id,
                flags=0,
            ),
            debug_handle=debug_handle,
        )
        xnn_graph.xnodes.append(ser_node)
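
With the visitor registered, any aten.tanh.default node in an exported graph is serialized as an XNNTanh entry in the delegate payload. A minimal sketch of exercising the lowering end to end (the to_edge_transform_and_lower and XnnpackPartitioner entry points are assumed from the public ExecuTorch API, not shown in this diff):

import torch

from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
from executorch.exir import to_edge_transform_and_lower


class TanhModule(torch.nn.Module):
    def forward(self, x):
        return torch.tanh(x)


# Export, then lower; with this change the tanh node should be absorbed
# into the XNNPACK delegate instead of staying in the portable graph.
exported = torch.export.export(TanhModule(), (torch.randn(4),))
program = to_edge_transform_and_lower(
    exported, partitioner=[XnnpackPartitioner()]
).to_executorch()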
backends/xnnpack/partition/config/__init__.py:

@@ -50,6 +50,7 @@
     SoftmaxConfig,
     SquareRootConfig,
     SubConfig,
+    TanhConfig,
     UpsampleBilinear2dConfig,
 )
 from executorch.backends.xnnpack.partition.config.node_configs import (

@@ -101,6 +102,7 @@
     PreluConfig,
     ReciprocalSquareRootConfig,
     ReLUConfig,
+    TanhConfig,
     # SDPAConfig, TODO: D60553559: preserving SDPA for fairseq fails
     SigmoidConfig,
     SliceCopyConfig,
backends/xnnpack/partition/config/generic_node_configs.py:

@@ -378,6 +378,13 @@ def supported_precision_types(self) -> List[ConfigPrecisionType]:
         return [ConfigPrecisionType.FP32]


+class TanhConfig(GenericNodePartitionerConfig):
+    target_name = "tanh.default"
+
+    def supported_precision_types(self) -> List[ConfigPrecisionType]:
+        return [ConfigPrecisionType.FP32]
+
+
 class MeanDimConfig(GenericNodePartitionerConfig):
     target_name = "mean.dim"
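
A GenericNodePartitionerConfig subclass only declares its target_name and supported precisions; the shared base class performs the actual node matching. A hypothetical sketch of that kind of match (the helper below is illustrative, not the real base-class code):

import torch


def node_matches(node: torch.fx.Node, target_name: str) -> bool:
    # Edge ops stringify like "aten.tanh.default", so a suffix
    # comparison is enough to pick out candidate nodes.
    return node.op == "call_function" and str(node.target).endswith(target_name)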
backends/xnnpack/partition/configs.py:

@@ -67,6 +67,7 @@
     exir_ops.edge.aten.rsqrt.default,
     exir_ops.edge.aten.log.default,
     exir_ops.edge.aten.gelu.default,
+    exir_ops.edge.aten.tanh.default,
 ]

 SUPPORTED_MODULES = [
backends/xnnpack/runtime/XNNCompiler.cpp:

@@ -1513,6 +1513,36 @@ Error defineGeluNode(
   return Error::Ok;
 }

+/*
+Define serialized tanh node into the subgraph, using the remapped ids
+to map the serialized ids to the new ids generated when defining the
+tensor value.
+*/
+Error defineTanhNode(
+    xnn_subgraph_t subgraph_ptr,
+    const std::unordered_map<uint32_t, uint32_t>& remapped_ids,
+    const NodePtr node,
+    const fb_xnnpack::XNNGraph* graph) noexcept {
+  MAYBE_UNUSED(graph);
+
+  auto graph_node = node->xnode_union_as_XNNTanh();
+
+  xnn_status status = xnn_define_tanh(
+      subgraph_ptr,
+      remapped_ids.at(graph_node->input_id()),
+      remapped_ids.at(graph_node->output_id()),
+      graph_node->flags());
+
+  ET_CHECK_OR_RETURN_ERROR(
+      status == xnn_status_success,
+      Internal,
+      "Failed to create tanh node %i with code: %s",
+      node->debug_handle(),
+      xnn_status_to_string(status));
+
+  return Error::Ok;
+}
+
 /*
 Define serialized ceiling node into the subgraph, using the remapped ids
 to map the serialized ids, to the new ids generated when defining the
 tensor value
 */

@@ -2108,6 +2138,7 @@ DefineNodeFunc getDefineNodeFunc(fb_xnnpack::XNodeUnion nodeType) {
   _DEFINE(Hardswish)
   _DEFINE(LeakyReLU)
   _DEFINE(Log)
+  _DEFINE(Tanh)
   _DEFINE(Maximum)
   _DEFINE(Negate)
   _DEFINE(Square)
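
The _DEFINE(Tanh) line registers the new union member in getDefineNodeFunc's switch over serialized node types. The macro itself is outside this diff; it presumably looks something like this sketch:

// Hypothetical shape of the _DEFINE helper used above: map a
// fb_xnnpack::XNodeUnion tag to the matching define*Node function.
#define _DEFINE(name)                     \
  case fb_xnnpack::XNodeUnion::XNN##name: \
    return &define##name##Node;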
backends/xnnpack/serialization/runtime_schema.fbs:

@@ -146,6 +146,7 @@ union XNodeUnion {
   XNNReciprocalSquareRoot: _XNNNode1x1,
   XNNLog: _XNNNode1x1,
   XNNGelu: _XNNNode1x1,
+  XNNTanh: _XNNNode1x1,
 }

 union XValueUnion {
backends/xnnpack/serialization/schema.fbs:

@@ -142,6 +142,7 @@ union XNodeUnion {
   XNNReciprocalSquareRoot: _XNNNode1x1,
   XNNLog: _XNNNode1x1,
   XNNGelu: _XNNNode1x1,
+  XNNTanh: _XNNNode1x1,
 }

 union XValueUnion {
backends/xnnpack/serialization/xnnpack_graph_schema.py:

@@ -324,6 +324,11 @@ class XNNLog(XNNNode1x1):
     pass


+@dataclass
+class XNNTanh(XNNNode1x1):
+    pass
+
+
 @dataclass
 class XNNMaximum(XNNNode2x1):
     pass

@@ -396,6 +401,7 @@ class XNNScaledDotProductAttention:
     XNNReciprocalSquareRoot,
     XNNLog,
     XNNGelu,
+    XNNTanh,
 ]
backends/xnnpack/test/ops/test_tanh.py (new file):

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest

import torch
from executorch.backends.xnnpack.test.tester import Tester


class TestTanh(unittest.TestCase):
    def setUp(self):
        torch._dynamo.reset()

    class Tanh(torch.nn.Module):
        def __init__(self):
            super().__init__()

        def forward(self, x):
            return torch.tanh(x)

    def run_tanh_test(self, inputs):
        (
            Tester(self.Tanh(), inputs)
            .export()
            .check_count({"torch.ops.aten.tanh.default": 1})
            .to_edge_transform_and_lower()
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
            .check_not(["executorch_exir_dialects_edge__ops_aten_tanh_default"])
            .to_executorch()
            .serialize()
            .run_method_and_compare_outputs()
        )

    def test_fp16_tanh(self):
        inputs = (torch.randn(20).to(torch.float16),)
        self.run_tanh_test(inputs)

    def test_fp32_tanh(self):
        inputs = (torch.randn(20),)
        self.run_tanh_test(inputs)
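
The fp16 case passing against an FP32-only TanhConfig suggests the partitioner folds half-precision float inputs into the fp32 precision type. To run the new tests locally, something like this should work, assuming the file lands at backends/xnnpack/test/ops/test_tanh.py and the executorch package is importable from the working directory (both inferred from the imports above, not stated in the diff):

python -m unittest executorch.backends.xnnpack.test.ops.test_tanh -v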