-import glob
 import json
-import os
-import shutil

 import torch
 from trainer.io import get_last_checkpoint

-from tests import get_device_id, get_tests_output_path, run_cli
+from tests import get_device_id, run_cli
 from TTS.tts.configs.neuralhmm_tts_config import NeuralhmmTTSConfig

-config_path = os.path.join(get_tests_output_path(), "test_model_config.json")
-output_path = os.path.join(get_tests_output_path(), "train_outputs")
-parameter_path = os.path.join(get_tests_output_path(), "lj_parameters.pt")

-torch.save({"mean": -5.5138, "std": 2.0636, "init_transition_prob": 0.3212}, parameter_path)
+def test_train(tmp_path):
+    config_path = tmp_path / "test_model_config.json"
+    output_path = tmp_path / "train_outputs"
+    parameter_path = tmp_path / "lj_parameters.pt"

-config = NeuralhmmTTSConfig(
-    batch_size=3,
-    eval_batch_size=3,
-    num_loader_workers=0,
-    num_eval_loader_workers=0,
-    text_cleaner="phoneme_cleaners",
-    use_phonemes=True,
-    phoneme_language="en-us",
-    phoneme_cache_path=os.path.join(get_tests_output_path(), "train_outputs/phoneme_cache/"),
-    run_eval=True,
-    test_delay_epochs=-1,
-    mel_statistics_parameter_path=parameter_path,
-    epochs=1,
-    print_step=1,
-    test_sentences=[
-        "Be a voice, not an echo.",
-    ],
-    print_eval=True,
-    max_sampling_time=50,
-)
-config.audio.do_trim_silence = True
-config.audio.trim_db = 60
-config.save_json(config_path)
+    torch.save({"mean": -5.5138, "std": 2.0636, "init_transition_prob": 0.3212}, parameter_path)

+    config = NeuralhmmTTSConfig(
+        batch_size=3,
+        eval_batch_size=3,
+        num_loader_workers=0,
+        num_eval_loader_workers=0,
+        text_cleaner="phoneme_cleaners",
+        use_phonemes=True,
+        phoneme_language="en-us",
+        phoneme_cache_path=output_path / "phoneme_cache",
+        run_eval=True,
+        test_delay_epochs=-1,
+        mel_statistics_parameter_path=parameter_path,
+        epochs=1,
+        print_step=1,
+        test_sentences=[
+            "Be a voice, not an echo.",
+        ],
+        print_eval=True,
+        max_sampling_time=50,
+    )
+    config.audio.do_trim_silence = True
+    config.audio.trim_db = 60
+    config.save_json(config_path)

-# train the model for one epoch when mel parameters exists
-command_train = (
-    f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --config_path {config_path} "
-    f"--coqpit.output_path {output_path} "
-    "--coqpit.datasets.0.formatter ljspeech "
-    "--coqpit.datasets.0.meta_file_train metadata.csv "
-    "--coqpit.datasets.0.meta_file_val metadata.csv "
-    "--coqpit.datasets.0.path tests/data/ljspeech "
-    "--coqpit.test_delay_epochs 0 "
-)
-run_cli(command_train)
+    # train the model for one epoch when mel parameters exist
+    command_train = (
+        f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --config_path {config_path} "
+        f"--coqpit.output_path {output_path} "
+        "--coqpit.datasets.0.formatter ljspeech "
+        "--coqpit.datasets.0.meta_file_train metadata.csv "
+        "--coqpit.datasets.0.meta_file_val metadata.csv "
+        "--coqpit.datasets.0.path tests/data/ljspeech "
+        "--coqpit.test_delay_epochs 0 "
+    )
+    run_cli(command_train)

+    # train the model for one epoch when mel parameters have to be computed from the dataset
+    if parameter_path.is_file():
+        parameter_path.unlink()
+    command_train = (
+        f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --config_path {config_path} "
+        f"--coqpit.output_path {output_path} "
+        "--coqpit.datasets.0.formatter ljspeech "
+        "--coqpit.datasets.0.meta_file_train metadata.csv "
+        "--coqpit.datasets.0.meta_file_val metadata.csv "
+        "--coqpit.datasets.0.path tests/data/ljspeech "
+        "--coqpit.test_delay_epochs 0 "
+    )
+    run_cli(command_train)

-# train the model for one epoch when mel parameters have to be computed from the dataset
-if os.path.exists(parameter_path):
-    os.remove(parameter_path)
-command_train = (
-    f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --config_path {config_path} "
-    f"--coqpit.output_path {output_path} "
-    "--coqpit.datasets.0.formatter ljspeech "
-    "--coqpit.datasets.0.meta_file_train metadata.csv "
-    "--coqpit.datasets.0.meta_file_val metadata.csv "
-    "--coqpit.datasets.0.path tests/data/ljspeech "
-    "--coqpit.test_delay_epochs 0 "
-)
-run_cli(command_train)
+    # Find latest folder
+    continue_path = max(output_path.iterdir(), key=lambda p: p.stat().st_mtime)

-# Find latest folder
-continue_path = max(glob.glob(os.path.join(output_path, "*/")), key=os.path.getmtime)
+    # Inference using TTS API
+    continue_config_path = continue_path / "config.json"
+    continue_restore_path, _ = get_last_checkpoint(continue_path)
+    out_wav_path = tmp_path / "output.wav"

-# Inference using TTS API
-continue_config_path = os.path.join(continue_path, "config.json")
-continue_restore_path, _ = get_last_checkpoint(continue_path)
-out_wav_path = os.path.join(get_tests_output_path(), "output.wav")
+    # Check integrity of the config
+    with continue_config_path.open() as f:
+        config_loaded = json.load(f)
+    assert config_loaded["characters"] is not None
+    assert config_loaded["output_path"] in str(continue_path)
+    assert config_loaded["test_delay_epochs"] == 0

-# Check integrity of the config
-with open(continue_config_path, "r", encoding="utf-8") as f:
-    config_loaded = json.load(f)
-assert config_loaded["characters"] is not None
-assert config_loaded["output_path"] in continue_path
-assert config_loaded["test_delay_epochs"] == 0
+    # Load the model and run inference
+    inference_command = f"CUDA_VISIBLE_DEVICES='{get_device_id()}' tts --text 'This is an example.' --config_path {continue_config_path} --model_path {continue_restore_path} --out_path {out_wav_path}"
+    run_cli(inference_command)

-# Load the model and run inference
-inference_command = f"CUDA_VISIBLE_DEVICES='{get_device_id()}' tts --text 'This is an example.' --config_path {continue_config_path} --model_path {continue_restore_path} --out_path {out_wav_path}"
-run_cli(inference_command)
-
-# restore the model and continue training for one more epoch
-command_train = f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --continue_path {continue_path} "
-run_cli(command_train)
-shutil.rmtree(continue_path)
+    # restore the model and continue training for one more epoch
+    command_train = (
+        f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --continue_path {continue_path} "
+    )
+    run_cli(command_train)
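
Note on the pattern this diff adopts: pytest's built-in `tmp_path` fixture hands each test a fresh per-test `pathlib.Path` directory that pytest manages itself, which is what lets the module-level `get_tests_output_path()` paths and the trailing `shutil.rmtree(continue_path)` cleanup disappear, and `pathlib` idioms replace `os.path.join`, `glob.glob`, and `os.remove`. A minimal, self-contained sketch of the same idiom; the test name and directory names here are illustrative only, not part of this change:

import pathlib


def test_tmp_path_idiom(tmp_path: pathlib.Path):
    # tmp_path is a unique pathlib.Path per test; the / operator replaces os.path.join
    output_path = tmp_path / "train_outputs"
    output_path.mkdir()
    (output_path / "run_1").mkdir()
    (output_path / "run_2").mkdir()
    # Pick the most recently modified entry, as the diff does in place of
    # glob.glob(os.path.join(output_path, "*/")) + os.path.getmtime
    latest = max(output_path.iterdir(), key=lambda p: p.stat().st_mtime)
    assert latest in {output_path / "run_1", output_path / "run_2"}
    # No explicit cleanup needed: pytest prunes old tmp_path base directories
    # on its own, keeping only the most recent runs by default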