@@ -110,38 +110,60 @@ def wrap_cuda_model(args, model):


def init_optimizer_and_scheduler(args, configs, model, gan):
-    if configs['train_conf']['optim'] == 'adam':
-        optimizer = optim.Adam(model.parameters(), **configs['train_conf']['optim_conf'])
-    elif configs['train_conf']['optim'] == 'adamw':
-        optimizer = optim.AdamW(model.parameters(), **configs['train_conf']['optim_conf'])
-    else:
-        raise ValueError("unknown optimizer: " + configs['train_conf'])
-
-    if configs['train_conf']['scheduler'] == 'warmuplr':
-        scheduler_type = WarmupLR
-        scheduler = WarmupLR(optimizer, **configs['train_conf']['scheduler_conf'])
-    elif configs['train_conf']['scheduler'] == 'NoamHoldAnnealing':
-        scheduler_type = NoamHoldAnnealing
-        scheduler = NoamHoldAnnealing(optimizer, **configs['train_conf']['scheduler_conf'])
-    elif configs['train_conf']['scheduler'] == 'constantlr':
-        scheduler_type = ConstantLR
-        scheduler = ConstantLR(optimizer)
+    if gan is False:
+        if configs['train_conf']['optim'] == 'adam':
+            optimizer = optim.Adam(model.parameters(), **configs['train_conf']['optim_conf'])
+        elif configs['train_conf']['optim'] == 'adamw':
+            optimizer = optim.AdamW(model.parameters(), **configs['train_conf']['optim_conf'])
+        else:
+            raise ValueError("unknown optimizer: " + configs['train_conf']['optim'])
+
+        if configs['train_conf']['scheduler'] == 'warmuplr':
+            scheduler_type = WarmupLR
+            scheduler = WarmupLR(optimizer, **configs['train_conf']['scheduler_conf'])
+        elif configs['train_conf']['scheduler'] == 'NoamHoldAnnealing':
+            scheduler_type = NoamHoldAnnealing
+            scheduler = NoamHoldAnnealing(optimizer, **configs['train_conf']['scheduler_conf'])
+        elif configs['train_conf']['scheduler'] == 'constantlr':
+            scheduler_type = ConstantLR
+            scheduler = ConstantLR(optimizer)
+        else:
+            raise ValueError("unknown scheduler: " + configs['train_conf']['scheduler'])
+
+        # use deepspeed optimizer for speedup
+        if args.train_engine == "deepspeed":
+            def scheduler(opt):
+                return scheduler_type(opt, **configs['train_conf']['scheduler_conf'])
+            model, optimizer, _, scheduler = deepspeed.initialize(
+                args=args,
+                model=model,
+                optimizer=None,
+                lr_scheduler=scheduler,
+                model_parameters=model.parameters())
+
+        optimizer_d, scheduler_d = None, None
+
    else:
-        raise ValueError("unknown scheduler: " + configs['train_conf'])
-
-    # use deepspeed optimizer for speedup
-    if args.train_engine == "deepspeed":
-        def scheduler(opt):
-            return scheduler_type(opt, **configs['train_conf']['scheduler_conf'])
-        model, optimizer, _, scheduler = deepspeed.initialize(
-            args=args,
-            model=model,
-            optimizer=None,
-            lr_scheduler=scheduler,
-            model_parameters=model.parameters())
-
-    # currently we wrap generator and discriminator in one model, so we cannot use deepspeed
-    if gan is True:
+        # currently we wrap generator and discriminator in one model, so we cannot use deepspeed
+        if configs['train_conf']['optim'] == 'adam':
+            optimizer = optim.Adam(model.module.generator.parameters(), **configs['train_conf']['optim_conf'])
+        elif configs['train_conf']['optim'] == 'adamw':
+            optimizer = optim.AdamW(model.module.generator.parameters(), **configs['train_conf']['optim_conf'])
+        else:
+            raise ValueError("unknown optimizer: " + configs['train_conf']['optim'])
+
+        if configs['train_conf']['scheduler'] == 'warmuplr':
+            scheduler_type = WarmupLR
+            scheduler = WarmupLR(optimizer, **configs['train_conf']['scheduler_conf'])
+        elif configs['train_conf']['scheduler'] == 'NoamHoldAnnealing':
+            scheduler_type = NoamHoldAnnealing
+            scheduler = NoamHoldAnnealing(optimizer, **configs['train_conf']['scheduler_conf'])
+        elif configs['train_conf']['scheduler'] == 'constantlr':
+            scheduler_type = ConstantLR
+            scheduler = ConstantLR(optimizer)
+        else:
+            raise ValueError("unknown scheduler: " + configs['train_conf']['scheduler'])
+
        if configs['train_conf']['optim_d'] == 'adam':
            optimizer_d = optim.Adam(model.module.discriminator.parameters(), **configs['train_conf']['optim_conf'])
        elif configs['train_conf']['optim_d'] == 'adamw':
@@ -160,8 +182,6 @@ def scheduler(opt):
            scheduler_d = ConstantLR(optimizer_d)
        else:
            raise ValueError("unknown scheduler: " + configs['train_conf'])
-    else:
-        optimizer_d, scheduler_d = None, None
    return model, optimizer, scheduler, optimizer_d, scheduler_d


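For orientation, here is a minimal sketch of the train_conf block that init_optimizer_and_scheduler reads. Only the key names come from the code above (plus scheduler_d, implied by the elided lines between the two hunks); the concrete values are illustrative assumptions:

    # hypothetical config, expressed as a Python dict; values are examples only
    configs = {
        'train_conf': {
            'optim': 'adamw',                       # model/generator optimizer: 'adam' or 'adamw'
            'optim_conf': {'lr': 1e-4},
            'scheduler': 'warmuplr',                # 'warmuplr', 'NoamHoldAnnealing' or 'constantlr'
            'scheduler_conf': {'warmup_steps': 25000},
            'optim_d': 'adamw',                     # discriminator optimizer, read only when gan=True
            'scheduler_d': 'constantlr',            # discriminator scheduler, read only when gan=True
        }
    }

With gan=False the function returns (model, optimizer, scheduler, None, None); with gan=True it builds separate optimizers and schedulers for model.module.generator and model.module.discriminator.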
@@ -216,7 +236,7 @@ def cosyvoice_join(group_join, info_dict):
        return False


-def batch_forward(model, batch, info_dict):
+def batch_forward(model, batch, scaler, info_dict):
    device = int(os.environ.get('LOCAL_RANK', 0))

    dtype = info_dict["dtype"]
@@ -228,7 +248,7 @@ def batch_forward(model, batch, info_dict):
        dtype = torch.float32

    if info_dict['train_engine'] == 'torch_ddp':
-        autocast = nullcontext()
+        autocast = torch.cuda.amp.autocast(enabled=scaler is not None)
    else:
        autocast = torch.cuda.amp.autocast(enabled=True, dtype=dtype, cache_enabled=False)

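The torch_ddp change above makes autocasting conditional on the presence of a scaler. A standalone sketch (not from this repo) of why passing scaler=None reproduces the old nullcontext() behavior: with enabled=False, torch.cuda.amp.autocast is a no-op context manager.

    import torch
    import torch.nn as nn

    model = nn.Linear(8, 8).cuda()           # placeholder model
    x = torch.randn(2, 8, device='cuda')     # placeholder input
    scaler = torch.cuda.amp.GradScaler()     # set to None to disable mixed precision

    with torch.cuda.amp.autocast(enabled=scaler is not None):
        y = model(x)                         # runs in fp16 when enabled, fp32 otherwise
    print(y.dtype)                           # torch.float16 with a scaler, torch.float32 without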
@@ -237,27 +257,40 @@ def batch_forward(model, batch, info_dict):
    return info_dict


-def batch_backward(model, info_dict):
+def batch_backward(model, scaler, info_dict):
    if info_dict["train_engine"] == "deepspeed":
        scaled_loss = model.backward(info_dict['loss_dict']['loss'])
    else:
        scaled_loss = info_dict['loss_dict']['loss'] / info_dict['accum_grad']
-        scaled_loss.backward()
+        if scaler is not None:
+            scaler.scale(scaled_loss).backward()
+        else:
+            scaled_loss.backward()

    info_dict['loss_dict']['loss'] = scaled_loss
    return info_dict


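Two things happen to the loss in the non-deepspeed branch: dividing by accum_grad averages the gradient over the accumulation window, and scaler.scale() then multiplies the already-divided loss by the current scale factor so small fp16 gradients don't flush to zero. A standalone sketch of the same pattern, where loader, compute_loss and accum_grad are placeholders:

    # gradient accumulation under AMP; names are illustrative, not from this repo
    for batch_idx, batch in enumerate(loader):
        with torch.cuda.amp.autocast(enabled=scaler is not None):
            loss = compute_loss(model, batch)
        loss = loss / accum_grad                  # average over the accumulation window
        if scaler is not None:
            scaler.scale(loss).backward()         # scale up to avoid fp16 gradient underflow
        else:
            loss.backward()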
-def update_parameter_and_lr(model, optimizer, scheduler, info_dict):
+def update_parameter_and_lr(model, optimizer, scheduler, scaler, info_dict):
    grad_norm = 0.0
    if info_dict['train_engine'] == "deepspeed":
        info_dict["is_gradient_accumulation_boundary"] = model.is_gradient_accumulation_boundary()
        model.step()
        grad_norm = model.get_global_grad_norm()
    elif (info_dict['batch_idx'] + 1) % info_dict["accum_grad"] == 0:
-        grad_norm = clip_grad_norm_(model.parameters(), info_dict['grad_clip'])
-        if torch.isfinite(grad_norm):
-            optimizer.step()
+        # Use mixed precision training
+        if scaler is not None:
+            scaler.unscale_(optimizer)
+            grad_norm = clip_grad_norm_(model.parameters(), info_dict['grad_clip'])
+            # No explicit finite check needed here: if the gradients contain
+            # inf/nan values, scaler.step() skips optimizer.step() automatically.
+            scaler.step(optimizer)
+            scaler.update()
+        else:
+            grad_norm = clip_grad_norm_(model.parameters(), info_dict['grad_clip'])
+            if torch.isfinite(grad_norm):
+                optimizer.step()
        optimizer.zero_grad()
        scheduler.step()
    info_dict["lr"] = optimizer.param_groups[0]['lr']