Skip to content

Commit 162f527

Browse files
committed
test: move vocoder training tests into one file
Gets rid of all the duplication in the tests
1 parent b71669a commit 162f527

9 files changed

+117
-306
lines changed

TTS/bin/train_vocoder.py

+5-2
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
import logging
22
import os
3+
import sys
34
from dataclasses import dataclass, field
5+
from typing import Optional
46

57
from trainer import Trainer, TrainerArgs
68

@@ -16,7 +18,7 @@ class TrainVocoderArgs(TrainerArgs):
1618
config_path: str = field(default=None, metadata={"help": "Path to the config file."})
1719

1820

19-
def main():
21+
def main(arg_list: Optional[list[str]] = None):
2022
"""Run `tts` model training directly by a `config.json` file."""
2123
setup_logger("TTS", level=logging.INFO, screen=True, formatter=ConsoleFormatter())
2224

@@ -25,7 +27,7 @@ def main():
2527
parser = train_args.init_argparse(arg_prefix="")
2628

2729
# override trainer args from comman-line args
28-
args, config_overrides = parser.parse_known_args()
30+
args, config_overrides = parser.parse_known_args(arg_list)
2931
train_args.parse_args(args)
3032

3133
# load config.json and register
@@ -75,6 +77,7 @@ def main():
7577
parse_command_line_args=False,
7678
)
7779
trainer.fit()
80+
sys.exit(0)
7881

7982

8083
if __name__ == "__main__":

tests/vocoder_tests/test_fullband_melgan_train.py

-42
This file was deleted.

tests/vocoder_tests/test_hifigan_train.py

-41
This file was deleted.

tests/vocoder_tests/test_melgan_train.py

-42
This file was deleted.

tests/vocoder_tests/test_multiband_melgan_train.py

-43
This file was deleted.

tests/vocoder_tests/test_parallel_wavegan_train.py

-43
This file was deleted.

tests/vocoder_tests/test_training.py

+112
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,112 @@
1+
import glob
2+
import os
3+
4+
import pytest
5+
6+
from tests import run_main
7+
from TTS.bin.train_vocoder import main
8+
from TTS.vocoder.configs import (
9+
FullbandMelganConfig,
10+
HifiganConfig,
11+
MelganConfig,
12+
MultibandMelganConfig,
13+
ParallelWaveganConfig,
14+
WavegradConfig,
15+
WavernnConfig,
16+
)
17+
from TTS.vocoder.models.wavernn import WavernnArgs
18+
19+
GITHUB_ACTIONS = os.getenv("GITHUB_ACTIONS") == "true"
20+
21+
BASE_CONFIG = {
22+
"batch_size": 8,
23+
"eval_batch_size": 8,
24+
"num_loader_workers": 0,
25+
"num_eval_loader_workers": 0,
26+
"run_eval": True,
27+
"test_delay_epochs": -1,
28+
"epochs": 1,
29+
"seq_len": 8192,
30+
"eval_split_size": 1,
31+
"print_step": 1,
32+
"print_eval": True,
33+
"data_path": "tests/data/ljspeech",
34+
}
35+
36+
DISCRIMINATOR_MODEL_PARAMS = {
37+
"base_channels": 16,
38+
"max_channels": 64,
39+
"downsample_factors": [4, 4, 4],
40+
}
41+
42+
43+
def create_config(config_class, **overrides):
44+
params = {**BASE_CONFIG, **overrides}
45+
return config_class(**params)
46+
47+
48+
def run_train(tmp_path, config):
49+
config_path = str(tmp_path / "test_vocoder_config.json")
50+
output_path = tmp_path / "train_outputs"
51+
config.output_path = output_path
52+
config.audio.do_trim_silence = True
53+
config.audio.trim_db = 60
54+
config.save_json(config_path)
55+
56+
# Train the model for one epoch
57+
run_main(main, ["--config_path", config_path])
58+
59+
# Find the latest folder
60+
continue_path = str(max(glob.glob(os.path.join(output_path, "*/")), key=os.path.getmtime))
61+
62+
# Restore the model and continue training for one more epoch
63+
run_main(main, ["--continue_path", continue_path])
64+
65+
66+
def test_train_hifigan(tmp_path):
67+
config = create_config(HifiganConfig, seq_len=1024)
68+
run_train(tmp_path, config)
69+
70+
71+
def test_train_melgan(tmp_path):
72+
config = create_config(
73+
MelganConfig,
74+
batch_size=4,
75+
eval_batch_size=4,
76+
seq_len=2048,
77+
discriminator_model_params=DISCRIMINATOR_MODEL_PARAMS,
78+
)
79+
run_train(tmp_path, config)
80+
81+
82+
def test_train_multiband_melgan(tmp_path):
83+
config = create_config(
84+
MultibandMelganConfig, steps_to_start_discriminator=1, discriminator_model_params=DISCRIMINATOR_MODEL_PARAMS
85+
)
86+
run_train(tmp_path, config)
87+
88+
89+
def test_train_fullband_melgan(tmp_path):
90+
config = create_config(FullbandMelganConfig, discriminator_model_params=DISCRIMINATOR_MODEL_PARAMS)
91+
run_train(tmp_path, config)
92+
93+
94+
def test_train_parallel_wavegan(tmp_path):
95+
config = create_config(ParallelWaveganConfig, batch_size=4, eval_batch_size=4, seq_len=2048)
96+
run_train(tmp_path, config)
97+
98+
99+
# TODO: Reactivate after improving CI run times
100+
@pytest.mark.skipif(GITHUB_ACTIONS, reason="Takes ~2h on CI (15min/step vs 8sec/step locally)")
101+
def test_train_wavegrad(tmp_path):
102+
config = create_config(WavegradConfig, test_noise_schedule={"min_val": 1e-6, "max_val": 1e-2, "num_steps": 2})
103+
run_train(tmp_path, config)
104+
105+
106+
def test_train_wavernn(tmp_path):
107+
config = create_config(
108+
WavernnConfig,
109+
model_args=WavernnArgs(),
110+
seq_len=256, # For shorter test time
111+
)
112+
run_train(tmp_path, config)

0 commit comments

Comments
 (0)