@@ -110,30 +110,29 @@ def wrap_cuda_model(args, model):
110
110
111
111
112
112
def init_optimizer_and_scheduler (args , configs , model , gan ):
113
- key = 'train_conf_gan' if gan is True else 'train_conf'
114
- if configs [key ]['optim' ] == 'adam' :
115
- optimizer = optim .Adam (model .parameters (), ** configs [key ]['optim_conf' ])
116
- elif configs [key ]['optim' ] == 'adamw' :
117
- optimizer = optim .AdamW (model .parameters (), ** configs [key ]['optim_conf' ])
113
+ if configs ['train_conf' ]['optim' ] == 'adam' :
114
+ optimizer = optim .Adam (model .parameters (), ** configs ['train_conf' ]['optim_conf' ])
115
+ elif configs ['train_conf' ]['optim' ] == 'adamw' :
116
+ optimizer = optim .AdamW (model .parameters (), ** configs ['train_conf' ]['optim_conf' ])
118
117
else :
119
- raise ValueError ("unknown optimizer: " + configs [key ])
118
+        raise ValueError("unknown optimizer: " + configs['train_conf']['optim'])
120
119
121
- if configs [key ]['scheduler' ] == 'warmuplr' :
120
+ if configs ['train_conf' ]['scheduler' ] == 'warmuplr' :
122
121
scheduler_type = WarmupLR
123
- scheduler = WarmupLR (optimizer , ** configs [key ]['scheduler_conf' ])
124
- elif configs [key ]['scheduler' ] == 'NoamHoldAnnealing' :
122
+ scheduler = WarmupLR (optimizer , ** configs ['train_conf' ]['scheduler_conf' ])
123
+ elif configs ['train_conf' ]['scheduler' ] == 'NoamHoldAnnealing' :
125
124
scheduler_type = NoamHoldAnnealing
126
- scheduler = NoamHoldAnnealing (optimizer , ** configs [key ]['scheduler_conf' ])
127
- elif configs [key ]['scheduler' ] == 'constantlr' :
125
+ scheduler = NoamHoldAnnealing (optimizer , ** configs ['train_conf' ]['scheduler_conf' ])
126
+ elif configs ['train_conf' ]['scheduler' ] == 'constantlr' :
128
127
scheduler_type = ConstantLR
129
128
scheduler = ConstantLR (optimizer )
130
129
else :
131
- raise ValueError ("unknown scheduler: " + configs [key ])
130
+        raise ValueError("unknown scheduler: " + configs['train_conf']['scheduler'])
132
131
133
132
# use deepspeed optimizer for speedup
134
133
if args .train_engine == "deepspeed" :
135
134
def scheduler (opt ):
136
- return scheduler_type (opt , ** configs [key ]['scheduler_conf' ])
135
+ return scheduler_type (opt , ** configs ['train_conf' ]['scheduler_conf' ])
137
136
model , optimizer , _ , scheduler = deepspeed .initialize (
138
137
args = args ,
139
138
model = model ,
@@ -143,24 +142,24 @@ def scheduler(opt):
143
142
144
143
# currently we wrap generator and discriminator in one model, so we cannot use deepspeed
145
144
if gan is True :
146
- if configs [key ]['optim_d' ] == 'adam' :
147
- optimizer_d = optim .Adam (model .module .discriminator .parameters (), ** configs [key ]['optim_conf' ])
148
- elif configs [key ]['optim_d' ] == 'adamw' :
149
- optimizer_d = optim .AdamW (model .module .discriminator .parameters (), ** configs [key ]['optim_conf' ])
145
+ if configs ['train_conf' ]['optim_d' ] == 'adam' :
146
+ optimizer_d = optim .Adam (model .module .discriminator .parameters (), ** configs ['train_conf' ]['optim_conf' ])
147
+ elif configs ['train_conf' ]['optim_d' ] == 'adamw' :
148
+ optimizer_d = optim .AdamW (model .module .discriminator .parameters (), ** configs ['train_conf' ]['optim_conf' ])
150
149
else :
151
- raise ValueError ("unknown optimizer: " + configs [key ])
150
+            raise ValueError("unknown optimizer: " + configs['train_conf']['optim_d'])
152
151
153
- if configs [key ]['scheduler_d' ] == 'warmuplr' :
152
+ if configs ['train_conf' ]['scheduler_d' ] == 'warmuplr' :
154
153
scheduler_type = WarmupLR
155
- scheduler_d = WarmupLR (optimizer_d , ** configs [key ]['scheduler_conf' ])
156
- elif configs [key ]['scheduler_d' ] == 'NoamHoldAnnealing' :
154
+ scheduler_d = WarmupLR (optimizer_d , ** configs ['train_conf' ]['scheduler_conf' ])
155
+ elif configs ['train_conf' ]['scheduler_d' ] == 'NoamHoldAnnealing' :
157
156
scheduler_type = NoamHoldAnnealing
158
- scheduler_d = NoamHoldAnnealing (optimizer_d , ** configs [key ]['scheduler_conf' ])
159
- elif configs [key ]['scheduler' ] == 'constantlr' :
157
+ scheduler_d = NoamHoldAnnealing (optimizer_d , ** configs ['train_conf' ]['scheduler_conf' ])
158
+        elif configs['train_conf']['scheduler_d'] == 'constantlr':
160
159
scheduler_type = ConstantLR
161
160
scheduler_d = ConstantLR (optimizer_d )
162
161
else :
163
- raise ValueError ("unknown scheduler: " + configs [key ])
162
+            raise ValueError("unknown scheduler: " + configs['train_conf']['scheduler_d'])
164
163
else :
165
164
optimizer_d , scheduler_d = None , None
166
165
return model , optimizer , scheduler , optimizer_d , scheduler_d
0 commit comments