Skip to content

Commit 3a7d1ae

Browse files
authored
Merge branch 'main' into main
2 parents d9c5b55 + e2994fc commit 3a7d1ae

File tree

2 files changed

+39
-3
lines changed

2 files changed

+39
-3
lines changed

infer-web.py

+26-3
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
import os
22
import shutil
33
import sys
4-
import json # Mangio fork using json for preset saving
4+
import json # Mangio fork using json for preset saving
5+
import math
56

67
import signal
78

@@ -934,6 +935,23 @@ def change_f0(if_f0_3, sr2, version19, step2b, gpus6, gpu_info9, extraction_crep
934935
)
935936

936937

938+
global log_interval
939+
940+
941+
def set_log_interval(exp_dir, batch_size12):
942+
log_interval = 1
943+
944+
folder_path = os.path.join(exp_dir, "1_16k_wavs")
945+
946+
if os.path.exists(folder_path) and os.path.isdir(folder_path):
947+
wav_files = [f for f in os.listdir(folder_path) if f.endswith(".wav")]
948+
if wav_files:
949+
sample_size = len(wav_files)
950+
log_interval = math.ceil(sample_size / batch_size12)
951+
952+
return log_interval
953+
954+
937955
# but3.click(click_train,[exp_dir1,sr2,if_f0_3,save_epoch10,total_epoch11,batch_size12,if_save_latest13,pretrained_G14,pretrained_D15,gpus16])
938956
def click_train(
939957
exp_dir1,
@@ -960,6 +978,9 @@ def click_train(
960978
if version19 == "v1"
961979
else "%s/3_feature768" % (exp_dir)
962980
)
981+
982+
log_interval = set_log_interval(exp_dir, batch_size12)
983+
963984
if if_f0_3:
964985
f0_dir = "%s/2a_f0" % (exp_dir)
965986
f0nsf_dir = "%s/2b-f0nsf" % (exp_dir)
@@ -1029,7 +1050,7 @@ def click_train(
10291050
####
10301051
cmd = (
10311052
config.python_cmd
1032-
+ " train_nsf_sim_cache_sid_load_pretrain.py -e %s -sr %s -f0 %s -bs %s -g %s -te %s -se %s %s %s -l %s -c %s -sw %s -v %s"
1053+
+ " train_nsf_sim_cache_sid_load_pretrain.py -e %s -sr %s -f0 %s -bs %s -g %s -te %s -se %s %s %s -l %s -c %s -sw %s -v %s -li %s"
10331054
% (
10341055
exp_dir1,
10351056
sr2,
@@ -1044,12 +1065,13 @@ def click_train(
10441065
1 if if_cache_gpu17 == True else 0,
10451066
1 if if_save_every_weights18 == True else 0,
10461067
version19,
1068+
log_interval,
10471069
)
10481070
)
10491071
else:
10501072
cmd = (
10511073
config.python_cmd
1052-
+ " train_nsf_sim_cache_sid_load_pretrain.py -e %s -sr %s -f0 %s -bs %s -te %s -se %s %s %s -l %s -c %s -sw %s -v %s"
1074+
+ " train_nsf_sim_cache_sid_load_pretrain.py -e %s -sr %s -f0 %s -bs %s -te %s -se %s %s %s -l %s -c %s -sw %s -v %s -li %s"
10531075
% (
10541076
exp_dir1,
10551077
sr2,
@@ -1063,6 +1085,7 @@ def click_train(
10631085
1 if if_cache_gpu17 == True else 0,
10641086
1 if if_save_every_weights18 == True else 0,
10651087
version19,
1088+
log_interval,
10661089
)
10671090
)
10681091
print(cmd)

train/utils.py

+13
Original file line numberDiff line numberDiff line change
@@ -352,6 +352,9 @@ def get_hparams(init=True):
352352
required=True,
353353
help="if caching the dataset in GPU memory, 1 or 0",
354354
)
355+
parser.add_argument(
356+
"-li", "--log_interval", type=int, required=True, help="log interval"
357+
)
355358

356359
args = parser.parse_args()
357360
name = args.experiment_dir
@@ -391,6 +394,16 @@ def get_hparams(init=True):
391394
hparams.save_every_weights = args.save_every_weights
392395
hparams.if_cache_data_in_gpu = args.if_cache_data_in_gpu
393396
hparams.data.training_files = "%s/filelist.txt" % experiment_dir
397+
398+
hparams.train.log_interval = args.log_interval
399+
400+
# Update log_interval in the 'train' section of the config dictionary
401+
config["train"]["log_interval"] = args.log_interval
402+
403+
# Save the updated config back to the config_save_path
404+
with open(config_save_path, "w") as f:
405+
json.dump(config, f, indent=4)
406+
394407
return hparams
395408

396409

0 commit comments

Comments
 (0)