
Commit 634edfa

Merge pull request #983 from Shengqiang-Li/main
feat: Support DPO
2 parents f0b8e89 + a22873e commit 634edfa

File tree: 12 files changed (+2156, -3 lines)


Diff for: cosyvoice/bin/train_dpo.py (+187 lines)
@@ -0,0 +1,187 @@
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function
import argparse
import datetime
import logging
logging.getLogger('matplotlib').setLevel(logging.WARNING)
from copy import deepcopy
import os
import torch
import torch.distributed as dist
import deepspeed

from hyperpyyaml import load_hyperpyyaml

from torch.distributed.elastic.multiprocessing.errors import record

from cosyvoice.utils.executor_dpo import Executor
from cosyvoice.utils.train_utils_dpo import (
    init_distributed,
    init_dataset_and_dataloader,
    init_optimizer_and_scheduler,
    init_summarywriter, save_model,
    wrap_cuda_model, check_modify_and_save_config)


def get_args():
    parser = argparse.ArgumentParser(description='training your network')
    parser.add_argument('--train_engine',
                        default='torch_ddp',
                        choices=['torch_ddp', 'deepspeed'],
                        help='Engine for parallel training')
    parser.add_argument('--model', required=True, help='model which will be trained')
    parser.add_argument('--config', required=True, help='config file')
    parser.add_argument('--train_data', required=True, help='train data file')
    parser.add_argument('--cv_data', required=True, help='cv data file')
    parser.add_argument('--checkpoint', help='checkpoint model')
    parser.add_argument('--model_dir', required=True, help='save model dir')
    parser.add_argument('--tensorboard_dir',
                        default='tensorboard',
                        help='tensorboard log dir')
    parser.add_argument('--ddp.dist_backend',
                        dest='dist_backend',
                        default='nccl',
                        choices=['nccl', 'gloo'],
                        help='distributed backend')
    parser.add_argument('--num_workers',
                        default=0,
                        type=int,
                        help='number of subprocess workers for reading')
    parser.add_argument('--prefetch',
                        default=100,
                        type=int,
                        help='prefetch number')
    parser.add_argument('--pin_memory',
                        action='store_true',
                        default=False,
                        help='Use pinned memory buffers for reading')
    parser.add_argument('--use_amp',
                        action='store_true',
                        default=False,
                        help='Use automatic mixed precision training')
    parser.add_argument('--deepspeed.save_states',
                        dest='save_states',
                        default='model_only',
                        choices=['model_only', 'model+optimizer'],
                        help='save model/optimizer states')
    parser.add_argument('--timeout',
                        default=60,
                        type=int,
                        help='timeout (in seconds) of cosyvoice_join.')
    parser.add_argument('--dpo',
                        action='store_true',
                        default=False,
                        help='Use Direct Preference Optimization')
    parser.add_argument('--beta',
                        default=0.01,
                        type=float,
                        help='beta for DPO training')
    parser = deepspeed.add_config_arguments(parser)
    args = parser.parse_args()
    return args


@record
def main():
    args = get_args()
    logging.basicConfig(level=logging.DEBUG,
                        format='%(asctime)s %(levelname)s %(message)s')
    # GAN training has some special initialization logic
    gan = True if args.model == 'hifigan' else False

    override_dict = {k: None for k in ['llm', 'flow', 'hift', 'hifigan'] if k != args.model}
    if gan is True:
        override_dict.pop('hift')
    with open(args.config, 'r') as f:
        configs = load_hyperpyyaml(f, overrides=override_dict)
    if gan is True:
        configs['train_conf'] = configs['train_conf_gan']
    configs['train_conf'].update(vars(args))

    # Init env for ddp
    init_distributed(args)

    # Get dataset & dataloader
    train_dataset, cv_dataset, train_data_loader, cv_data_loader = \
        init_dataset_and_dataloader(args, configs, gan)

    # Do some sanity checks and save config to args.model_dir
    configs = check_modify_and_save_config(args, configs)

    # Tensorboard summary
    writer = init_summarywriter(args)

    # Load checkpoint; DPO additionally keeps a frozen copy of the
    # initial policy as the reference model
    model = configs[args.model]
    ref_model = None
    if args.dpo:
        ref_model = deepcopy(model)
    start_step, start_epoch = 0, -1
    if args.checkpoint is not None:
        if os.path.exists(args.checkpoint):
            state_dict = torch.load(args.checkpoint, map_location='cpu')
            model.load_state_dict(state_dict, strict=False)
            if args.dpo:
                ref_model.load_state_dict(state_dict, strict=False)
            if 'step' in state_dict:
                start_step = state_dict['step']
            if 'epoch' in state_dict:
                start_epoch = state_dict['epoch']
        else:
            logging.warning('checkpoint {} does not exist!'.format(args.checkpoint))

    # Dispatch model from cpu to gpu
    model = wrap_cuda_model(args, model)
    if args.dpo:
        ref_model = wrap_cuda_model(args, ref_model)

    # Get optimizer & scheduler
    model, optimizer, scheduler, optimizer_d, scheduler_d = init_optimizer_and_scheduler(args, configs, model, gan)
    if args.dpo:
        # run ref_model through the same init so it is engine-wrapped;
        # its optimizer and scheduler are discarded since it is never updated
        ref_model, _, _, _, _ = init_optimizer_and_scheduler(args, configs, ref_model, gan)
    scheduler.set_step(start_step)
    if scheduler_d is not None:
        scheduler_d.set_step(start_step)

    # Save init checkpoints
    info_dict = deepcopy(configs['train_conf'])
    info_dict['step'] = start_step
    info_dict['epoch'] = start_epoch
    save_model(model, 'init', info_dict)

    # Get executor
    executor = Executor(gan=gan, dpo=args.dpo, beta=args.beta)
    executor.step = start_step

    # Init scaler, used for pytorch amp mixed precision training
    scaler = torch.cuda.amp.GradScaler() if args.use_amp else None
    print('start step {} start epoch {}'.format(start_step, start_epoch))
    # Start training loop
    for epoch in range(start_epoch + 1, info_dict['max_epoch']):
        executor.epoch = epoch
        train_dataset.set_epoch(epoch)
        dist.barrier()
        group_join = dist.new_group(backend="gloo", timeout=datetime.timedelta(seconds=args.timeout))
        if gan is True:
            executor.train_one_epoc_gan(model, optimizer, scheduler, optimizer_d, scheduler_d, train_data_loader, cv_data_loader,
                                        writer, info_dict, scaler, group_join)
        else:
            executor.train_one_epoc(model, optimizer, scheduler, train_data_loader, cv_data_loader, writer, info_dict, scaler, group_join, ref_model)
        dist.destroy_process_group(group_join)


if __name__ == '__main__':
    main()
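
Note on the DPO pieces: --beta is forwarded to the Executor, and the loss itself lives in cosyvoice/utils/executor_dpo.py (added by this PR but not shown in this diff). For orientation, here is a minimal sketch of the standard DPO objective (Rafailov et al., 2023) that such an executor computes; the function and tensor names below are hypothetical, not the PR's actual code:

import torch.nn.functional as F

def dpo_loss(policy_chosen_logps, policy_rejected_logps,
             ref_chosen_logps, ref_rejected_logps, beta=0.01):
    # how much more the policy favors each sample than the frozen reference does
    chosen_logratio = policy_chosen_logps - ref_chosen_logps
    rejected_logratio = policy_rejected_logps - ref_rejected_logps
    # -log sigmoid(beta * margin): widen the chosen/rejected margin, with beta
    # controlling how strongly the policy is kept close to the reference
    return -F.logsigmoid(beta * (chosen_logratio - rejected_logratio)).mean()

The reference model is only ever queried for log-probabilities (typically under torch.no_grad()), which is why the script above discards its optimizer and scheduler.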

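Assuming train_dpo.py is launched the same way as the existing cosyvoice/bin/train.py (the recipes use torchrun), a hypothetical invocation exercising the new flags would look like the following; the paths and GPU count are placeholders, only the flags come from this diff:

torchrun --nnodes=1 --nproc_per_node=2 cosyvoice/bin/train_dpo.py \
    --train_engine torch_ddp \
    --model llm \
    --config conf/cosyvoice.yaml \
    --train_data data/train/data.list \
    --cv_data data/dev/data.list \
    --model_dir exp/llm_dpo \
    --dpo --beta 0.01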