|
| 1 | +import glob |
| 2 | +import os |
| 3 | + |
| 4 | +import pytest |
| 5 | + |
| 6 | +from tests import run_main |
| 7 | +from TTS.bin.train_vocoder import main |
| 8 | +from TTS.vocoder.configs import ( |
| 9 | + FullbandMelganConfig, |
| 10 | + HifiganConfig, |
| 11 | + MelganConfig, |
| 12 | + MultibandMelganConfig, |
| 13 | + ParallelWaveganConfig, |
| 14 | + WavegradConfig, |
| 15 | + WavernnConfig, |
| 16 | +) |
| 17 | +from TTS.vocoder.models.wavernn import WavernnArgs |
| 18 | + |
| 19 | +GITHUB_ACTIONS = os.getenv("GITHUB_ACTIONS") == "true" |
| 20 | + |
| 21 | +BASE_CONFIG = { |
| 22 | + "batch_size": 8, |
| 23 | + "eval_batch_size": 8, |
| 24 | + "num_loader_workers": 0, |
| 25 | + "num_eval_loader_workers": 0, |
| 26 | + "run_eval": True, |
| 27 | + "test_delay_epochs": -1, |
| 28 | + "epochs": 1, |
| 29 | + "seq_len": 8192, |
| 30 | + "eval_split_size": 1, |
| 31 | + "print_step": 1, |
| 32 | + "print_eval": True, |
| 33 | + "data_path": "tests/data/ljspeech", |
| 34 | +} |
| 35 | + |
| 36 | +DISCRIMINATOR_MODEL_PARAMS = { |
| 37 | + "base_channels": 16, |
| 38 | + "max_channels": 64, |
| 39 | + "downsample_factors": [4, 4, 4], |
| 40 | +} |
| 41 | + |
| 42 | + |
| 43 | +def create_config(config_class, **overrides): |
| 44 | + params = {**BASE_CONFIG, **overrides} |
| 45 | + return config_class(**params) |
| 46 | + |
| 47 | + |
| 48 | +def run_train(tmp_path, config): |
| 49 | + config_path = str(tmp_path / "test_vocoder_config.json") |
| 50 | + output_path = tmp_path / "train_outputs" |
| 51 | + config.output_path = output_path |
| 52 | + config.audio.do_trim_silence = True |
| 53 | + config.audio.trim_db = 60 |
| 54 | + config.save_json(config_path) |
| 55 | + |
| 56 | + # Train the model for one epoch |
| 57 | + run_main(main, ["--config_path", config_path]) |
| 58 | + |
| 59 | + # Find the latest folder |
| 60 | + continue_path = str(max(glob.glob(os.path.join(output_path, "*/")), key=os.path.getmtime)) |
| 61 | + |
| 62 | + # Restore the model and continue training for one more epoch |
| 63 | + run_main(main, ["--continue_path", continue_path]) |
| 64 | + |
| 65 | + |
| 66 | +def test_train_hifigan(tmp_path): |
| 67 | + config = create_config(HifiganConfig, seq_len=1024) |
| 68 | + run_train(tmp_path, config) |
| 69 | + |
| 70 | + |
| 71 | +def test_train_melgan(tmp_path): |
| 72 | + config = create_config( |
| 73 | + MelganConfig, |
| 74 | + batch_size=4, |
| 75 | + eval_batch_size=4, |
| 76 | + seq_len=2048, |
| 77 | + discriminator_model_params=DISCRIMINATOR_MODEL_PARAMS, |
| 78 | + ) |
| 79 | + run_train(tmp_path, config) |
| 80 | + |
| 81 | + |
| 82 | +def test_train_multiband_melgan(tmp_path): |
| 83 | + config = create_config( |
| 84 | + MultibandMelganConfig, steps_to_start_discriminator=1, discriminator_model_params=DISCRIMINATOR_MODEL_PARAMS |
| 85 | + ) |
| 86 | + run_train(tmp_path, config) |
| 87 | + |
| 88 | + |
| 89 | +def test_train_fullband_melgan(tmp_path): |
| 90 | + config = create_config(FullbandMelganConfig, discriminator_model_params=DISCRIMINATOR_MODEL_PARAMS) |
| 91 | + run_train(tmp_path, config) |
| 92 | + |
| 93 | + |
| 94 | +def test_train_parallel_wavegan(tmp_path): |
| 95 | + config = create_config(ParallelWaveganConfig, batch_size=4, eval_batch_size=4, seq_len=2048) |
| 96 | + run_train(tmp_path, config) |
| 97 | + |
| 98 | + |
| 99 | +# TODO: Reactivate after improving CI run times |
| 100 | +@pytest.mark.skipif(GITHUB_ACTIONS, reason="Takes ~2h on CI (15min/step vs 8sec/step locally)") |
| 101 | +def test_train_wavegrad(tmp_path): |
| 102 | + config = create_config(WavegradConfig, test_noise_schedule={"min_val": 1e-6, "max_val": 1e-2, "num_steps": 2}) |
| 103 | + run_train(tmp_path, config) |
| 104 | + |
| 105 | + |
| 106 | +def test_train_wavernn(tmp_path): |
| 107 | + config = create_config( |
| 108 | + WavernnConfig, |
| 109 | + model_args=WavernnArgs(), |
| 110 | + seq_len=256, # For shorter test time |
| 111 | + ) |
| 112 | + run_train(tmp_path, config) |
0 commit comments