sys.path.append('{}/../..'.format(ROOT_DIR))
sys.path.append('{}/../../third_party/Matcha-TTS'.format(ROOT_DIR))
from cosyvoice.cli.cosyvoice import CosyVoice, CosyVoice2
+from cosyvoice.utils.file_utils import logging


def get_dummy_input(batch_size, seq_len, out_channels, device):
@@ -51,6 +52,7 @@ def get_args():
    return args


+@torch.no_grad()
def main():
    args = get_args()
    logging.basicConfig(level=logging.DEBUG,
@@ -60,56 +62,132 @@ def main():
        model = CosyVoice(args.model_dir)
    except Exception:
        try:
-            model = CosyVoice2(args.model_dir)
+            # NOTE: set use_flow_cache=True when exporting for cache (streaming) inference
+            model = CosyVoice2(args.model_dir, use_flow_cache=True)
        except Exception:
            raise TypeError('no valid model_type!')

-    # 1. export flow decoder estimator
-    estimator = model.model.flow.decoder.estimator
-
-    device = model.model.device
-    batch_size, seq_len = 2, 256
-    out_channels = model.model.flow.decoder.estimator.out_channels
-    x, mask, mu, t, spks, cond = get_dummy_input(batch_size, seq_len, out_channels, device)
-    torch.onnx.export(
-        estimator,
-        (x, mask, mu, t, spks, cond),
-        '{}/flow.decoder.estimator.fp32.onnx'.format(args.model_dir),
-        export_params=True,
-        opset_version=18,
-        do_constant_folding=True,
-        input_names=['x', 'mask', 'mu', 't', 'spks', 'cond'],
-        output_names=['estimator_out'],
-        dynamic_axes={
-            'x': {2: 'seq_len'},
-            'mask': {2: 'seq_len'},
-            'mu': {2: 'seq_len'},
-            'cond': {2: 'seq_len'},
-            'estimator_out': {2: 'seq_len'},
-        }
-    )
-
-    # 2. test computation consistency
-    option = onnxruntime.SessionOptions()
-    option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
-    option.intra_op_num_threads = 1
-    providers = ['CUDAExecutionProvider' if torch.cuda.is_available() else 'CPUExecutionProvider']
-    estimator_onnx = onnxruntime.InferenceSession('{}/flow.decoder.estimator.fp32.onnx'.format(args.model_dir),
-                                                  sess_options=option, providers=providers)
-
-    for _ in tqdm(range(10)):
-        x, mask, mu, t, spks, cond = get_dummy_input(batch_size, random.randint(16, 512), out_channels, device)
-        output_pytorch = estimator(x, mask, mu, t, spks, cond)
-        ort_inputs = {
-            'x': x.cpu().numpy(),
-            'mask': mask.cpu().numpy(),
-            'mu': mu.cpu().numpy(),
-            't': t.cpu().numpy(),
-            'spks': spks.cpu().numpy(),
-            'cond': cond.cpu().numpy()
-        }
-        output_onnx = estimator_onnx.run(None, ort_inputs)[0]
-        torch.testing.assert_allclose(output_pytorch, torch.from_numpy(output_onnx).to(device), rtol=1e-2, atol=1e-4)
+    if not isinstance(model, CosyVoice2):
+        # 1. export flow decoder estimator
+        estimator = model.model.flow.decoder.estimator
+        estimator.eval()
+
+        device = model.model.device
+        batch_size, seq_len = 2, 256
+        out_channels = model.model.flow.decoder.estimator.out_channels
+        x, mask, mu, t, spks, cond = get_dummy_input(batch_size, seq_len, out_channels, device)
+        torch.onnx.export(
+            estimator,
+            (x, mask, mu, t, spks, cond),
+            '{}/flow.decoder.estimator.fp32.onnx'.format(args.model_dir),
+            export_params=True,
+            opset_version=18,
+            do_constant_folding=True,
+            input_names=['x', 'mask', 'mu', 't', 'spks', 'cond'],
+            output_names=['estimator_out'],
+            dynamic_axes={
+                'x': {2: 'seq_len'},
+                'mask': {2: 'seq_len'},
+                'mu': {2: 'seq_len'},
+                'cond': {2: 'seq_len'},
+                'estimator_out': {2: 'seq_len'},
+            }
+        )
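+        # (axis 2 is the time axis for x/mask/mu/cond, so one exported graph serves any sequence length)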
+
+        # 2. test computation consistency
+        option = onnxruntime.SessionOptions()
+        option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
+        option.intra_op_num_threads = 1
+        providers = ['CUDAExecutionProvider' if torch.cuda.is_available() else 'CPUExecutionProvider']
+        estimator_onnx = onnxruntime.InferenceSession('{}/flow.decoder.estimator.fp32.onnx'.format(args.model_dir),
+                                                      sess_options=option, providers=providers)
+
+        for _ in tqdm(range(10)):
+            x, mask, mu, t, spks, cond = get_dummy_input(batch_size, random.randint(16, 512), out_channels, device)
+            output_pytorch = estimator(x, mask, mu, t, spks, cond)
+            ort_inputs = {
+                'x': x.cpu().numpy(),
+                'mask': mask.cpu().numpy(),
+                'mu': mu.cpu().numpy(),
+                't': t.cpu().numpy(),
+                'spks': spks.cpu().numpy(),
+                'cond': cond.cpu().numpy()
+            }
+            output_onnx = estimator_onnx.run(None, ort_inputs)[0]
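+            # torch.testing.assert_allclose is deprecated upstream; torch.testing.assert_close is the current equivalent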
+            torch.testing.assert_allclose(output_pytorch, torch.from_numpy(output_onnx).to(device), rtol=1e-2, atol=1e-4)
+        logging.info('successfully export estimator')
+    else:
+        # 1. export flow decoder estimator
+        estimator = model.model.flow.decoder.estimator
+        estimator.forward = estimator.forward_chunk
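+        # rebinding forward to forward_chunk makes torch.onnx.export trace the cache-aware path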
+        estimator.eval()
+
+        device = model.model.device
+        batch_size, seq_len = 2, 256
+        out_channels = model.model.flow.decoder.estimator.out_channels
+        x, mask, mu, t, spks, cond = get_dummy_input(batch_size, seq_len, out_channels, device)
+        cache = model.model.init_flow_cache()['decoder_cache']
+        cache.pop('offset')
+        cache = {k: v[0] for k, v in cache.items()}
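+        # 'offset' is python-side bookkeeping, not a graph input; only the cache tensors below are exported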
+        torch.onnx.export(
+            estimator,
+            (x, mask, mu, t, spks, cond,
+             cache['down_blocks_conv_cache'],
+             cache['down_blocks_kv_cache'],
+             cache['mid_blocks_conv_cache'],
+             cache['mid_blocks_kv_cache'],
+             cache['up_blocks_conv_cache'],
+             cache['up_blocks_kv_cache'],
+             cache['final_blocks_conv_cache']),
+            '{}/flow.decoder.estimator.fp32.onnx'.format(args.model_dir),
+            export_params=True,
+            opset_version=18,
+            do_constant_folding=True,
+            input_names=['x', 'mask', 'mu', 't', 'spks', 'cond', 'down_blocks_conv_cache', 'down_blocks_kv_cache', 'mid_blocks_conv_cache', 'mid_blocks_kv_cache',
+                         'up_blocks_conv_cache', 'up_blocks_kv_cache', 'final_blocks_conv_cache'],
+            output_names=['estimator_out', 'down_blocks_conv_cache_out', 'down_blocks_kv_cache_out', 'mid_blocks_conv_cache_out', 'mid_blocks_kv_cache_out',
+                          'up_blocks_conv_cache_out', 'up_blocks_kv_cache_out', 'final_blocks_conv_cache_out'],
+            dynamic_axes={
+                'x': {2: 'seq_len'},
+                'mask': {2: 'seq_len'},
+                'mu': {2: 'seq_len'},
+                'cond': {2: 'seq_len'},
+                'down_blocks_kv_cache': {3: 'cache_in_len'},
+                'mid_blocks_kv_cache': {3: 'cache_in_len'},
+                'up_blocks_kv_cache': {3: 'cache_in_len'},
+                'estimator_out': {2: 'seq_len'},
+                'down_blocks_kv_cache_out': {3: 'cache_out_len'},
+                'mid_blocks_kv_cache_out': {3: 'cache_out_len'},
+                'up_blocks_kv_cache_out': {3: 'cache_out_len'},
+            }
+        )
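+        # (the kv caches grow along axis 3 as chunks accumulate, hence the cache_in_len/cache_out_len symbols;
+        #  the conv caches have no dynamic axes, so their shapes stay fixed across chunks)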
+
+        # 2. test computation consistency
+        option = onnxruntime.SessionOptions()
+        option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
+        option.intra_op_num_threads = 1
+        providers = ['CUDAExecutionProvider' if torch.cuda.is_available() else 'CPUExecutionProvider']
+        estimator_onnx = onnxruntime.InferenceSession('{}/flow.decoder.estimator.fp32.onnx'.format(args.model_dir),
+                                                      sess_options=option, providers=providers)
+
+        for _ in tqdm(range(10)):
+            x, mask, mu, t, spks, cond = get_dummy_input(batch_size, random.randint(16, 512), out_channels, device)
+            cache = model.model.init_flow_cache()['decoder_cache']
+            cache.pop('offset')
+            cache = {k: v[0] for k, v in cache.items()}
+            output_pytorch = estimator(x, mask, mu, t, spks, cond, **{k: v.clone() for k, v in cache.items()})
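+            # caches are cloned so the PyTorch pass cannot mutate the tensors the ONNX run reuses below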
+            ort_inputs = {
+                'x': x.cpu().numpy(),
+                'mask': mask.cpu().numpy(),
+                'mu': mu.cpu().numpy(),
+                't': t.cpu().numpy(),
+                'spks': spks.cpu().numpy(),
+                'cond': cond.cpu().numpy(),
+            }
+            output_onnx = estimator_onnx.run(None, {**ort_inputs, **{k: v.clone().cpu().numpy() for k, v in cache.items()}})
+            for i, j in zip(output_pytorch, output_onnx):
+                torch.testing.assert_allclose(i, torch.from_numpy(j).to(device), rtol=1e-2, atol=1e-4)
+        logging.info('successfully export estimator')


if __name__ == "__main__":
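
For reference, a minimal sketch of how the cache-aware graph exported above might be driven chunk by chunk at inference time. The input/output names follow the input_names/output_names declared in this diff; run_chunks, the feed dicts, and init_cache (a dict of numpy arrays matching the 'decoder_cache' entries that init_flow_cache() prepares in this script) are illustrative assumptions, not CosyVoice's actual runtime code.

# hypothetical helper: stream chunks through the exported estimator, threading the caches
def run_chunks(session, chunks, init_cache):
    cache = dict(init_cache)
    out_names = [o.name for o in session.get_outputs()]  # 'estimator_out', then the '*_cache_out' names
    outputs = []
    for feed in chunks:  # each feed maps 'x', 'mask', 'mu', 't', 'spks', 'cond' to numpy arrays
        out = session.run(None, {**feed, **cache})
        outputs.append(out[0])  # 'estimator_out' for this chunk
        # feed every updated cache back in under its input name ('<name>_out' -> '<name>')
        cache = {name[:-len('_out')]: value for name, value in zip(out_names[1:], out[1:])}
    return outputs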