Commit d554db7

fix hifigan init bug
1 parent d1f7c1c commit d554db7

File tree

4 files changed: +96 −55 lines changed

cosyvoice/bin/train.py

+9 −2

@@ -68,6 +68,10 @@ def get_args():
                         action='store_true',
                         default=False,
                         help='Use pinned memory buffers used for reading')
+    parser.add_argument('--use_amp',
+                        action='store_true',
+                        default=False,
+                        help='Use automatic mixed precision training')
     parser.add_argument('--deepspeed.save_states',
                         dest='save_states',
                         default='model_only',
@@ -133,6 +137,9 @@ def main():
     # Get executor
     executor = Executor(gan=gan)
 
+    # Init scaler, used for pytorch amp mixed precision training
+    scaler = torch.cuda.amp.GradScaler() if args.use_amp else None
+
     # Start training loop
     for epoch in range(info_dict['max_epoch']):
         executor.epoch = epoch
@@ -141,9 +148,9 @@ def main():
        group_join = dist.new_group(backend="gloo", timeout=datetime.timedelta(seconds=args.timeout))
        if gan is True:
            executor.train_one_epoc_gan(model, optimizer, scheduler, optimizer_d, scheduler_d, train_data_loader, cv_data_loader,
-                                       writer, info_dict, group_join)
+                                       writer, info_dict, scaler, group_join)
        else:
-           executor.train_one_epoc(model, optimizer, scheduler, train_data_loader, cv_data_loader, writer, info_dict, group_join)
+           executor.train_one_epoc(model, optimizer, scheduler, train_data_loader, cv_data_loader, writer, info_dict, scaler, group_join)
        dist.destroy_process_group(group_join)
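In short, train.py now owns a single torch.cuda.amp.GradScaler that exists only when --use_amp is given, and hands it to the executor. A minimal, self-contained sketch of that wiring (only the new flag is parsed here; parse_args([]) is used purely so the snippet runs without CLI arguments):

import argparse

import torch

parser = argparse.ArgumentParser()
# Mirrors the flag added above; store_true keeps AMP off by default.
parser.add_argument('--use_amp', action='store_true', default=False,
                    help='Use automatic mixed precision training')
args = parser.parse_args([])  # empty list only so this sketch runs standalone

# With --use_amp a scaler is created; without it, None is threaded through the
# executor and training falls back to ordinary fp32 behaviour.
scaler = torch.cuda.amp.GradScaler() if args.use_amp else None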

cosyvoice/utils/executor.py

+12 −12

@@ -32,7 +32,7 @@ def __init__(self, gan: bool = False):
         self.rank = int(os.environ.get('RANK', 0))
         self.device = torch.device('cuda:{}'.format(self.rank))
 
-    def train_one_epoc(self, model, optimizer, scheduler, train_data_loader, cv_data_loader, writer, info_dict, group_join):
+    def train_one_epoc(self, model, optimizer, scheduler, train_data_loader, cv_data_loader, writer, info_dict, scaler, group_join):
         ''' Train one epoch
         '''
 
@@ -65,10 +65,10 @@ def train_one_epoc(self, model, optimizer, scheduler, train_data_loader, cv_data
                 context = nullcontext
 
             with context():
-                info_dict = batch_forward(model, batch_dict, info_dict)
-                info_dict = batch_backward(model, info_dict)
+                info_dict = batch_forward(model, batch_dict, scaler, info_dict)
+                info_dict = batch_backward(model, scaler, info_dict)
 
-            info_dict = update_parameter_and_lr(model, optimizer, scheduler, info_dict)
+            info_dict = update_parameter_and_lr(model, optimizer, scheduler, scaler, info_dict)
             log_per_step(writer, info_dict)
             # NOTE specify save_per_step in cosyvoice.yaml if you want to enable step save
             if info_dict['save_per_step'] > 0 and (self.step + 1) % info_dict['save_per_step'] == 0 and \
@@ -82,7 +82,7 @@ def train_one_epoc(self, model, optimizer, scheduler, train_data_loader, cv_data
         self.cv(model, cv_data_loader, writer, info_dict, on_batch_end=True)
 
     def train_one_epoc_gan(self, model, optimizer, scheduler, optimizer_d, scheduler_d, train_data_loader, cv_data_loader,
-                           writer, info_dict, group_join):
+                           writer, info_dict, scaler, group_join):
         ''' Train one epoch
         '''
 
@@ -116,16 +116,16 @@ def train_one_epoc_gan(self, model, optimizer, scheduler, optimizer_d, scheduler
 
             with context():
                 batch_dict['turn'] = 'discriminator'
-                info_dict = batch_forward(model, batch_dict, info_dict)
-                info_dict = batch_backward(model, info_dict)
-                info_dict = update_parameter_and_lr(model, optimizer_d, scheduler_d, info_dict)
+                info_dict = batch_forward(model, batch_dict, scaler, info_dict)
+                info_dict = batch_backward(model, scaler, info_dict)
+                info_dict = update_parameter_and_lr(model, optimizer_d, scheduler_d, scaler, info_dict)
             optimizer.zero_grad()
             log_per_step(writer, info_dict)
             with context():
                 batch_dict['turn'] = 'generator'
-                info_dict = batch_forward(model, batch_dict, info_dict)
-                info_dict = batch_backward(model, info_dict)
-                info_dict = update_parameter_and_lr(model, optimizer, scheduler, info_dict)
+                info_dict = batch_forward(model, batch_dict, scaler, info_dict)
+                info_dict = batch_backward(model, scaler, info_dict)
+                info_dict = update_parameter_and_lr(model, optimizer, scheduler, scaler, info_dict)
             optimizer_d.zero_grad()
             log_per_step(writer, info_dict)
             # NOTE specify save_per_step in cosyvoice.yaml if you want to enable step save
@@ -157,7 +157,7 @@ def cv(self, model, cv_data_loader, writer, info_dict, on_batch_end=True):
 
                 if self.gan is True:
                     batch_dict['turn'] = 'generator'
                 info_dict = batch_forward(model, batch_dict, info_dict)
+                info_dict = batch_forward(model, batch_dict, None, info_dict)
 
                 for k, v in info_dict['loss_dict'].items():
                     if k not in total_loss_dict:
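The executor simply threads the scaler through every batch_forward / batch_backward / update_parameter_and_lr call; in the GAN loop the same scaler serves both turns. A condensed sketch of one GAN training step after this change, assuming the helpers are imported from cosyvoice.utils.train_utils as executor.py does (join/save/logging bookkeeping omitted):

from cosyvoice.utils.train_utils import batch_forward, batch_backward, update_parameter_and_lr


def gan_train_step(model, batch_dict, optimizer, scheduler,
                   optimizer_d, scheduler_d, scaler, info_dict):
    # Discriminator turn: forward/backward/update with the discriminator optimizer.
    batch_dict['turn'] = 'discriminator'
    info_dict = batch_forward(model, batch_dict, scaler, info_dict)
    info_dict = batch_backward(model, scaler, info_dict)
    info_dict = update_parameter_and_lr(model, optimizer_d, scheduler_d, scaler, info_dict)
    optimizer.zero_grad()    # drop generator grads produced during the D turn

    # Generator turn: same scaler, generator optimizer and scheduler.
    batch_dict['turn'] = 'generator'
    info_dict = batch_forward(model, batch_dict, scaler, info_dict)
    info_dict = batch_backward(model, scaler, info_dict)
    info_dict = update_parameter_and_lr(model, optimizer, scheduler, scaler, info_dict)
    optimizer_d.zero_grad()  # drop discriminator grads produced during the G turn
    return info_dict

Note that cv() passes None as the scaler, so cross-validation always takes the non-AMP forward path under torch_ddp.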

cosyvoice/utils/train_utils.py

+74 −41

@@ -110,38 +110,60 @@ def wrap_cuda_model(args, model):
 
 
 def init_optimizer_and_scheduler(args, configs, model, gan):
-    if configs['train_conf']['optim'] == 'adam':
-        optimizer = optim.Adam(model.parameters(), **configs['train_conf']['optim_conf'])
-    elif configs['train_conf']['optim'] == 'adamw':
-        optimizer = optim.AdamW(model.parameters(), **configs['train_conf']['optim_conf'])
-    else:
-        raise ValueError("unknown optimizer: " + configs['train_conf'])
-
-    if configs['train_conf']['scheduler'] == 'warmuplr':
-        scheduler_type = WarmupLR
-        scheduler = WarmupLR(optimizer, **configs['train_conf']['scheduler_conf'])
-    elif configs['train_conf']['scheduler'] == 'NoamHoldAnnealing':
-        scheduler_type = NoamHoldAnnealing
-        scheduler = NoamHoldAnnealing(optimizer, **configs['train_conf']['scheduler_conf'])
-    elif configs['train_conf']['scheduler'] == 'constantlr':
-        scheduler_type = ConstantLR
-        scheduler = ConstantLR(optimizer)
+    if gan is False:
+        if configs['train_conf']['optim'] == 'adam':
+            optimizer = optim.Adam(model.parameters(), **configs['train_conf']['optim_conf'])
+        elif configs['train_conf']['optim'] == 'adamw':
+            optimizer = optim.AdamW(model.parameters(), **configs['train_conf']['optim_conf'])
+        else:
+            raise ValueError("unknown optimizer: " + configs['train_conf'])
+
+        if configs['train_conf']['scheduler'] == 'warmuplr':
+            scheduler_type = WarmupLR
+            scheduler = WarmupLR(optimizer, **configs['train_conf']['scheduler_conf'])
+        elif configs['train_conf']['scheduler'] == 'NoamHoldAnnealing':
+            scheduler_type = NoamHoldAnnealing
+            scheduler = NoamHoldAnnealing(optimizer, **configs['train_conf']['scheduler_conf'])
+        elif configs['train_conf']['scheduler'] == 'constantlr':
+            scheduler_type = ConstantLR
+            scheduler = ConstantLR(optimizer)
+        else:
+            raise ValueError("unknown scheduler: " + configs['train_conf'])
+
+        # use deepspeed optimizer for speedup
+        if args.train_engine == "deepspeed":
+            def scheduler(opt):
+                return scheduler_type(opt, **configs['train_conf']['scheduler_conf'])
+            model, optimizer, _, scheduler = deepspeed.initialize(
+                args=args,
+                model=model,
+                optimizer=None,
+                lr_scheduler=scheduler,
+                model_parameters=model.parameters())
+
+        optimizer_d, scheduler_d = None, None
+
     else:
-        raise ValueError("unknown scheduler: " + configs['train_conf'])
-
-    # use deepspeed optimizer for speedup
-    if args.train_engine == "deepspeed":
-        def scheduler(opt):
-            return scheduler_type(opt, **configs['train_conf']['scheduler_conf'])
-        model, optimizer, _, scheduler = deepspeed.initialize(
-            args=args,
-            model=model,
-            optimizer=None,
-            lr_scheduler=scheduler,
-            model_parameters=model.parameters())
-
-    # currently we wrap generator and discriminator in one model, so we cannot use deepspeed
-    if gan is True:
+        # currently we wrap generator and discriminator in one model, so we cannot use deepspeed
+        if configs['train_conf']['optim'] == 'adam':
+            optimizer = optim.Adam(model.module.generator.parameters(), **configs['train_conf']['optim_conf'])
+        elif configs['train_conf']['optim'] == 'adamw':
+            optimizer = optim.AdamW(model.module.generator.parameters(), **configs['train_conf']['optim_conf'])
+        else:
+            raise ValueError("unknown optimizer: " + configs['train_conf'])
+
+        if configs['train_conf']['scheduler'] == 'warmuplr':
+            scheduler_type = WarmupLR
+            scheduler = WarmupLR(optimizer, **configs['train_conf']['scheduler_conf'])
+        elif configs['train_conf']['scheduler'] == 'NoamHoldAnnealing':
+            scheduler_type = NoamHoldAnnealing
+            scheduler = NoamHoldAnnealing(optimizer, **configs['train_conf']['scheduler_conf'])
+        elif configs['train_conf']['scheduler'] == 'constantlr':
+            scheduler_type = ConstantLR
+            scheduler = ConstantLR(optimizer)
+        else:
+            raise ValueError("unknown scheduler: " + configs['train_conf'])
+
         if configs['train_conf']['optim_d'] == 'adam':
             optimizer_d = optim.Adam(model.module.discriminator.parameters(), **configs['train_conf']['optim_conf'])
         elif configs['train_conf']['optim_d'] == 'adamw':
@@ -160,8 +182,6 @@ def scheduler(opt):
             scheduler_d = ConstantLR(optimizer_d)
         else:
             raise ValueError("unknown scheduler: " + configs['train_conf'])
-    else:
-        optimizer_d, scheduler_d = None, None
     return model, optimizer, scheduler, optimizer_d, scheduler_d
 
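This restructuring is the "hifigan init bug" fix itself: the GAN path previously built its generator optimizer from model.parameters(), which also contains the discriminator weights, so those weights sat in two optimizers at once; now the generator and discriminator each get an optimizer over their own parameters. A toy sketch of the invariant the fix establishes (ToyGAN and its layer sizes are made up for illustration; the real model is DDP-wrapped, hence the model.module.* access in the diff above):

import torch
from torch import optim


class ToyGAN(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.generator = torch.nn.Linear(8, 8)
        self.discriminator = torch.nn.Linear(8, 1)


model = ToyGAN()

# Fixed behaviour, as in the diff: each optimizer owns a disjoint parameter set.
optimizer = optim.AdamW(model.generator.parameters(), lr=2e-4)
optimizer_d = optim.AdamW(model.discriminator.parameters(), lr=2e-4)

gen_ids = {id(p) for p in optimizer.param_groups[0]['params']}
disc_ids = {id(p) for p in optimizer_d.param_groups[0]['params']}
assert gen_ids.isdisjoint(disc_ids)  # no parameter is stepped by both optimizers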

@@ -216,7 +236,7 @@ def cosyvoice_join(group_join, info_dict):
         return False
 
 
-def batch_forward(model, batch, info_dict):
+def batch_forward(model, batch, scaler, info_dict):
     device = int(os.environ.get('LOCAL_RANK', 0))
 
     dtype = info_dict["dtype"]
@@ -228,7 +248,7 @@ def batch_forward(model, batch, info_dict):
         dtype = torch.float32
 
     if info_dict['train_engine'] == 'torch_ddp':
-        autocast = nullcontext()
+        autocast = torch.cuda.amp.autocast(enabled=scaler is not None)
     else:
         autocast = torch.cuda.amp.autocast(enabled=True, dtype=dtype, cache_enabled=False)
 
@@ -237,27 +257,40 @@
     return info_dict
 
 
-def batch_backward(model, info_dict):
+def batch_backward(model, scaler, info_dict):
     if info_dict["train_engine"] == "deepspeed":
         scaled_loss = model.backward(info_dict['loss_dict']['loss'])
     else:
         scaled_loss = info_dict['loss_dict']['loss'] / info_dict['accum_grad']
-        scaled_loss.backward()
+        if scaler is not None:
+            scaler.scale(scaled_loss).backward()
+        else:
+            scaled_loss.backward()
 
     info_dict['loss_dict']['loss'] = scaled_loss
     return info_dict
 
 
-def update_parameter_and_lr(model, optimizer, scheduler, info_dict):
+def update_parameter_and_lr(model, optimizer, scheduler, scaler, info_dict):
     grad_norm = 0.0
     if info_dict['train_engine'] == "deepspeed":
         info_dict["is_gradient_accumulation_boundary"] = model.is_gradient_accumulation_boundary()
         model.step()
         grad_norm = model.get_global_grad_norm()
     elif (info_dict['batch_idx'] + 1) % info_dict["accum_grad"] == 0:
-        grad_norm = clip_grad_norm_(model.parameters(), info_dict['grad_clip'])
-        if torch.isfinite(grad_norm):
-            optimizer.step()
+        # Use mixed precision training
+        if scaler is not None:
+            scaler.unscale_(optimizer)
+            grad_norm = clip_grad_norm_(model.parameters(), info_dict['grad_clip'])
+            # We don't check grad here since that if the gradient
+            # has inf/nan values, scaler.step will skip
+            # optimizer.step().
+            scaler.step(optimizer)
+            scaler.update()
+        else:
+            grad_norm = clip_grad_norm_(model.parameters(), info_dict['grad_clip'])
+            if torch.isfinite(grad_norm):
+                optimizer.step()
         optimizer.zero_grad()
         scheduler.step()
     info_dict["lr"] = optimizer.param_groups[0]['lr']
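Together, batch_backward and update_parameter_and_lr now follow the standard GradScaler recipe: scale the loss before backward, unscale before gradient clipping, and let scaler.step skip the update when gradients contain inf/nan. A self-contained sketch of that recipe under gradient accumulation (the toy model, lr, grad_clip and accum_grad values are placeholders, not values from the repo; a CUDA device is assumed):

import torch
from torch.nn.utils import clip_grad_norm_

model = torch.nn.Linear(16, 1).cuda()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
scaler = torch.cuda.amp.GradScaler()
grad_clip, accum_grad = 5.0, 2

for step in range(8):
    x = torch.randn(4, 16, device='cuda')
    with torch.cuda.amp.autocast(enabled=scaler is not None):
        loss = model(x).pow(2).mean() / accum_grad
    # Scale the loss so small fp16 gradients do not underflow.
    scaler.scale(loss).backward()
    if (step + 1) % accum_grad == 0:
        # Unscale before clipping so the threshold applies to the true gradients.
        scaler.unscale_(optimizer)
        clip_grad_norm_(model.parameters(), grad_clip)
        # scaler.step skips optimizer.step() if any gradient is inf/nan.
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad()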

examples/libritts/cosyvoice/run.sh

+1

@@ -99,6 +99,7 @@ if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
       --num_workers ${num_workers} \
       --prefetch ${prefetch} \
       --pin_memory \
+      --use_amp \
       --deepspeed_config ./conf/ds_stage2.json \
       --deepspeed.save_states model+optimizer
   done
