 # See the License for the specific language governing permissions and
 # limitations under the License.

-from contextlib import nullcontext
 import logging
 import os
 import torch
@@ -110,38 +109,60 @@ def wrap_cuda_model(args, model):


 def init_optimizer_and_scheduler(args, configs, model, gan):
-    if configs['train_conf']['optim'] == 'adam':
-        optimizer = optim.Adam(model.parameters(), **configs['train_conf']['optim_conf'])
-    elif configs['train_conf']['optim'] == 'adamw':
-        optimizer = optim.AdamW(model.parameters(), **configs['train_conf']['optim_conf'])
-    else:
-        raise ValueError("unknown optimizer: " + configs['train_conf'])
-
-    if configs['train_conf']['scheduler'] == 'warmuplr':
-        scheduler_type = WarmupLR
-        scheduler = WarmupLR(optimizer, **configs['train_conf']['scheduler_conf'])
-    elif configs['train_conf']['scheduler'] == 'NoamHoldAnnealing':
-        scheduler_type = NoamHoldAnnealing
-        scheduler = NoamHoldAnnealing(optimizer, **configs['train_conf']['scheduler_conf'])
-    elif configs['train_conf']['scheduler'] == 'constantlr':
-        scheduler_type = ConstantLR
-        scheduler = ConstantLR(optimizer)
+    if gan is False:
+        if configs['train_conf']['optim'] == 'adam':
+            optimizer = optim.Adam(model.parameters(), **configs['train_conf']['optim_conf'])
+        elif configs['train_conf']['optim'] == 'adamw':
+            optimizer = optim.AdamW(model.parameters(), **configs['train_conf']['optim_conf'])
+        else:
+            raise ValueError("unknown optimizer: " + configs['train_conf'])
+
+        if configs['train_conf']['scheduler'] == 'warmuplr':
+            scheduler_type = WarmupLR
+            scheduler = WarmupLR(optimizer, **configs['train_conf']['scheduler_conf'])
+        elif configs['train_conf']['scheduler'] == 'NoamHoldAnnealing':
+            scheduler_type = NoamHoldAnnealing
+            scheduler = NoamHoldAnnealing(optimizer, **configs['train_conf']['scheduler_conf'])
+        elif configs['train_conf']['scheduler'] == 'constantlr':
+            scheduler_type = ConstantLR
+            scheduler = ConstantLR(optimizer)
+        else:
+            raise ValueError("unknown scheduler: " + configs['train_conf'])
+
+        # use deepspeed optimizer for speedup
+        if args.train_engine == "deepspeed":
+            def scheduler(opt):
+                return scheduler_type(opt, **configs['train_conf']['scheduler_conf'])
+            model, optimizer, _, scheduler = deepspeed.initialize(
+                args=args,
+                model=model,
+                optimizer=None,
+                lr_scheduler=scheduler,
+                model_parameters=model.parameters())
+
+        optimizer_d, scheduler_d = None, None
+
     else:
-        raise ValueError("unknown scheduler: " + configs['train_conf'])
-
-    # use deepspeed optimizer for speedup
-    if args.train_engine == "deepspeed":
-        def scheduler(opt):
-            return scheduler_type(opt, **configs['train_conf']['scheduler_conf'])
-        model, optimizer, _, scheduler = deepspeed.initialize(
-            args=args,
-            model=model,
-            optimizer=None,
-            lr_scheduler=scheduler,
-            model_parameters=model.parameters())
-
-    # currently we wrap generator and discriminator in one model, so we cannot use deepspeed
-    if gan is True:
+        # currently we wrap generator and discriminator in one model, so we cannot use deepspeed
+        if configs['train_conf']['optim'] == 'adam':
+            optimizer = optim.Adam(model.module.generator.parameters(), **configs['train_conf']['optim_conf'])
+        elif configs['train_conf']['optim'] == 'adamw':
+            optimizer = optim.AdamW(model.module.generator.parameters(), **configs['train_conf']['optim_conf'])
+        else:
+            raise ValueError("unknown optimizer: " + configs['train_conf'])
+
+        if configs['train_conf']['scheduler'] == 'warmuplr':
+            scheduler_type = WarmupLR
+            scheduler = WarmupLR(optimizer, **configs['train_conf']['scheduler_conf'])
+        elif configs['train_conf']['scheduler'] == 'NoamHoldAnnealing':
+            scheduler_type = NoamHoldAnnealing
+            scheduler = NoamHoldAnnealing(optimizer, **configs['train_conf']['scheduler_conf'])
+        elif configs['train_conf']['scheduler'] == 'constantlr':
+            scheduler_type = ConstantLR
+            scheduler = ConstantLR(optimizer)
+        else:
+            raise ValueError("unknown scheduler: " + configs['train_conf'])
+
         if configs['train_conf']['optim_d'] == 'adam':
             optimizer_d = optim.Adam(model.module.discriminator.parameters(), **configs['train_conf']['optim_conf'])
         elif configs['train_conf']['optim_d'] == 'adamw':
@@ -160,8 +181,6 @@ def scheduler(opt):
             scheduler_d = ConstantLR(optimizer_d)
         else:
             raise ValueError("unknown scheduler: " + configs['train_conf'])
-    else:
-        optimizer_d, scheduler_d = None, None
     return model, optimizer, scheduler, optimizer_d, scheduler_d

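For reference, `init_optimizer_and_scheduler` only reads the keys shown above from `configs['train_conf']`, and `args` and `model` come from the surrounding training script. A minimal sketch of such a config and a non-GAN call; the values are illustrative placeholders, not taken from any released recipe:

```python
# Illustrative only: placeholder values, not from a released recipe.
configs = {
    'train_conf': {
        'optim': 'adam',                           # or 'adamw'
        'optim_conf': {'lr': 1e-3},
        'scheduler': 'warmuplr',                   # or 'NoamHoldAnnealing' / 'constantlr'
        'scheduler_conf': {'warmup_steps': 25000},
        'optim_d': 'adam',                         # discriminator optimizer, only used when gan is True
    }
}

# Non-GAN path: optimizer_d and scheduler_d come back as None.
model, optimizer, scheduler, optimizer_d, scheduler_d = \
    init_optimizer_and_scheduler(args, configs, model, gan=False)
```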
@@ -216,7 +235,7 @@ def cosyvoice_join(group_join, info_dict):
         return False


-def batch_forward(model, batch, info_dict):
+def batch_forward(model, batch, scaler, info_dict):
     device = int(os.environ.get('LOCAL_RANK', 0))

     dtype = info_dict["dtype"]
@@ -228,7 +247,7 @@ def batch_forward(model, batch, info_dict):
         dtype = torch.float32

     if info_dict['train_engine'] == 'torch_ddp':
-        autocast = nullcontext()
+        autocast = torch.cuda.amp.autocast(enabled=scaler is not None)
     else:
         autocast = torch.cuda.amp.autocast(enabled=True, dtype=dtype, cache_enabled=False)

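With this change the `torch_ddp` branch gates autocast on whether a `GradScaler` was passed in: `torch.cuda.amp.autocast(enabled=False)` is a no-op, so passing `scaler=None` keeps the forward pass in plain float32 as before. A hedged sketch of how the scaler might be created at the call site (the `use_amp` flag is an assumption, not part of this diff):

```python
import torch

# Assumed call-site flag; create a GradScaler only when mixed precision is requested.
use_amp = True
scaler = torch.cuda.amp.GradScaler() if use_amp else None
```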
@@ -237,27 +256,40 @@ def batch_forward(model, batch, info_dict):
     return info_dict


-def batch_backward(model, info_dict):
+def batch_backward(model, scaler, info_dict):
     if info_dict["train_engine"] == "deepspeed":
         scaled_loss = model.backward(info_dict['loss_dict']['loss'])
     else:
         scaled_loss = info_dict['loss_dict']['loss'] / info_dict['accum_grad']
-        scaled_loss.backward()
+        if scaler is not None:
+            scaler.scale(scaled_loss).backward()
+        else:
+            scaled_loss.backward()

     info_dict['loss_dict']['loss'] = scaled_loss
     return info_dict


-def update_parameter_and_lr(model, optimizer, scheduler, info_dict):
+def update_parameter_and_lr(model, optimizer, scheduler, scaler, info_dict):
     grad_norm = 0.0
     if info_dict['train_engine'] == "deepspeed":
         info_dict["is_gradient_accumulation_boundary"] = model.is_gradient_accumulation_boundary()
         model.step()
         grad_norm = model.get_global_grad_norm()
     elif (info_dict['batch_idx'] + 1) % info_dict["accum_grad"] == 0:
-        grad_norm = clip_grad_norm_(model.parameters(), info_dict['grad_clip'])
-        if torch.isfinite(grad_norm):
-            optimizer.step()
+        # Use mixed precision training
+        if scaler is not None:
+            scaler.unscale_(optimizer)
+            grad_norm = clip_grad_norm_(model.parameters(), info_dict['grad_clip'])
+            # We don't check the gradient here because, if it has
+            # inf/nan values, scaler.step will skip
+            # optimizer.step().
+            scaler.step(optimizer)
+            scaler.update()
+        else:
+            grad_norm = clip_grad_norm_(model.parameters(), info_dict['grad_clip'])
+            if torch.isfinite(grad_norm):
+                optimizer.step()
         optimizer.zero_grad()
         scheduler.step()
     info_dict["lr"] = optimizer.param_groups[0]['lr']
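Taken together, the updated signatures thread one `scaler` object through forward, backward, and the optimizer step. A rough sketch of a `torch_ddp` training step under these changes; `use_amp`, `train_data_loader`, and the initial `info_dict` are illustrative assumptions, not part of this diff:

```python
# Illustrative torch_ddp loop; not part of this diff.
scaler = torch.cuda.amp.GradScaler() if use_amp else None

for batch_idx, batch in enumerate(train_data_loader):
    info_dict['batch_idx'] = batch_idx
    # forward runs under autocast only when a scaler exists
    info_dict = batch_forward(model, batch, scaler, info_dict)
    # AMP: scaler.scale(loss).backward(); otherwise a plain backward
    info_dict = batch_backward(model, scaler, info_dict)
    # steps optimizer/scheduler every accum_grad batches; with AMP, scaler.step
    # silently skips the update when gradients contain inf/nan
    info_dict = update_parameter_and_lr(model, optimizer, scheduler, scaler, info_dict)
```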