Skip to content

Commit fae4229

Browse files
committed
test(xtts): use temp folder and pytest.parametrize
1 parent 0c6f6c9 commit fae4229

13 files changed

+880
-1070
lines changed
+73-75
Original file line numberDiff line numberDiff line change
@@ -1,92 +1,90 @@
1-
import glob
21
import json
3-
import os
4-
import shutil
52

63
import torch
74
from trainer.io import get_last_checkpoint
85

9-
from tests import get_device_id, get_tests_output_path, run_cli
6+
from tests import get_device_id, run_cli
107
from TTS.tts.configs.neuralhmm_tts_config import NeuralhmmTTSConfig
118

12-
config_path = os.path.join(get_tests_output_path(), "test_model_config.json")
13-
output_path = os.path.join(get_tests_output_path(), "train_outputs")
14-
parameter_path = os.path.join(get_tests_output_path(), "lj_parameters.pt")
159

16-
torch.save({"mean": -5.5138, "std": 2.0636, "init_transition_prob": 0.3212}, parameter_path)
10+
def test_train(tmp_path):
11+
config_path = tmp_path / "test_model_config.json"
12+
output_path = tmp_path / "train_outputs"
13+
parameter_path = tmp_path / "lj_parameters.pt"
1714

18-
config = NeuralhmmTTSConfig(
19-
batch_size=3,
20-
eval_batch_size=3,
21-
num_loader_workers=0,
22-
num_eval_loader_workers=0,
23-
text_cleaner="phoneme_cleaners",
24-
use_phonemes=True,
25-
phoneme_language="en-us",
26-
phoneme_cache_path=os.path.join(get_tests_output_path(), "train_outputs/phoneme_cache/"),
27-
run_eval=True,
28-
test_delay_epochs=-1,
29-
mel_statistics_parameter_path=parameter_path,
30-
epochs=1,
31-
print_step=1,
32-
test_sentences=[
33-
"Be a voice, not an echo.",
34-
],
35-
print_eval=True,
36-
max_sampling_time=50,
37-
)
38-
config.audio.do_trim_silence = True
39-
config.audio.trim_db = 60
40-
config.save_json(config_path)
15+
torch.save({"mean": -5.5138, "std": 2.0636, "init_transition_prob": 0.3212}, parameter_path)
4116

17+
config = NeuralhmmTTSConfig(
18+
batch_size=3,
19+
eval_batch_size=3,
20+
num_loader_workers=0,
21+
num_eval_loader_workers=0,
22+
text_cleaner="phoneme_cleaners",
23+
use_phonemes=True,
24+
phoneme_language="en-us",
25+
phoneme_cache_path=output_path / "phoneme_cache",
26+
run_eval=True,
27+
test_delay_epochs=-1,
28+
mel_statistics_parameter_path=parameter_path,
29+
epochs=1,
30+
print_step=1,
31+
test_sentences=[
32+
"Be a voice, not an echo.",
33+
],
34+
print_eval=True,
35+
max_sampling_time=50,
36+
)
37+
config.audio.do_trim_silence = True
38+
config.audio.trim_db = 60
39+
config.save_json(config_path)
4240

43-
# train the model for one epoch when mel parameters exists
44-
command_train = (
45-
f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --config_path {config_path} "
46-
f"--coqpit.output_path {output_path} "
47-
"--coqpit.datasets.0.formatter ljspeech "
48-
"--coqpit.datasets.0.meta_file_train metadata.csv "
49-
"--coqpit.datasets.0.meta_file_val metadata.csv "
50-
"--coqpit.datasets.0.path tests/data/ljspeech "
51-
"--coqpit.test_delay_epochs 0 "
52-
)
53-
run_cli(command_train)
41+
# train the model for one epoch when mel parameters exists
42+
command_train = (
43+
f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --config_path {config_path} "
44+
f"--coqpit.output_path {output_path} "
45+
"--coqpit.datasets.0.formatter ljspeech "
46+
"--coqpit.datasets.0.meta_file_train metadata.csv "
47+
"--coqpit.datasets.0.meta_file_val metadata.csv "
48+
"--coqpit.datasets.0.path tests/data/ljspeech "
49+
"--coqpit.test_delay_epochs 0 "
50+
)
51+
run_cli(command_train)
5452

53+
# train the model for one epoch when mel parameters have to be computed from the dataset
54+
if parameter_path.is_file():
55+
parameter_path.unlink()
56+
command_train = (
57+
f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --config_path {config_path} "
58+
f"--coqpit.output_path {output_path} "
59+
"--coqpit.datasets.0.formatter ljspeech "
60+
"--coqpit.datasets.0.meta_file_train metadata.csv "
61+
"--coqpit.datasets.0.meta_file_val metadata.csv "
62+
"--coqpit.datasets.0.path tests/data/ljspeech "
63+
"--coqpit.test_delay_epochs 0 "
64+
)
65+
run_cli(command_train)
5566

56-
# train the model for one epoch when mel parameters have to be computed from the dataset
57-
if os.path.exists(parameter_path):
58-
os.remove(parameter_path)
59-
command_train = (
60-
f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --config_path {config_path} "
61-
f"--coqpit.output_path {output_path} "
62-
"--coqpit.datasets.0.formatter ljspeech "
63-
"--coqpit.datasets.0.meta_file_train metadata.csv "
64-
"--coqpit.datasets.0.meta_file_val metadata.csv "
65-
"--coqpit.datasets.0.path tests/data/ljspeech "
66-
"--coqpit.test_delay_epochs 0 "
67-
)
68-
run_cli(command_train)
67+
# Find latest folder
68+
continue_path = max(output_path.iterdir(), key=lambda p: p.stat().st_mtime)
6969

70-
# Find latest folder
71-
continue_path = max(glob.glob(os.path.join(output_path, "*/")), key=os.path.getmtime)
70+
# Inference using TTS API
71+
continue_config_path = continue_path / "config.json"
72+
continue_restore_path, _ = get_last_checkpoint(continue_path)
73+
out_wav_path = tmp_path / "output.wav"
7274

73-
# Inference using TTS API
74-
continue_config_path = os.path.join(continue_path, "config.json")
75-
continue_restore_path, _ = get_last_checkpoint(continue_path)
76-
out_wav_path = os.path.join(get_tests_output_path(), "output.wav")
75+
# Check integrity of the config
76+
with continue_config_path.open() as f:
77+
config_loaded = json.load(f)
78+
assert config_loaded["characters"] is not None
79+
assert config_loaded["output_path"] in str(continue_path)
80+
assert config_loaded["test_delay_epochs"] == 0
7781

78-
# Check integrity of the config
79-
with open(continue_config_path, "r", encoding="utf-8") as f:
80-
config_loaded = json.load(f)
81-
assert config_loaded["characters"] is not None
82-
assert config_loaded["output_path"] in continue_path
83-
assert config_loaded["test_delay_epochs"] == 0
82+
# Load the model and run inference
83+
inference_command = f"CUDA_VISIBLE_DEVICES='{get_device_id()}' tts --text 'This is an example.' --config_path {continue_config_path} --model_path {continue_restore_path} --out_path {out_wav_path}"
84+
run_cli(inference_command)
8485

85-
# Load the model and run inference
86-
inference_command = f"CUDA_VISIBLE_DEVICES='{get_device_id()}' tts --text 'This is an example.' --config_path {continue_config_path} --model_path {continue_restore_path} --out_path {out_wav_path}"
87-
run_cli(inference_command)
88-
89-
# restore the model and continue training for one more epoch
90-
command_train = f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --continue_path {continue_path} "
91-
run_cli(command_train)
92-
shutil.rmtree(continue_path)
86+
# restore the model and continue training for one more epoch
87+
command_train = (
88+
f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --continue_path {continue_path} "
89+
)
90+
run_cli(command_train)
+73-75
Original file line numberDiff line numberDiff line change
@@ -1,92 +1,90 @@
1-
import glob
21
import json
3-
import os
4-
import shutil
52

63
import torch
74
from trainer.io import get_last_checkpoint
85

9-
from tests import get_device_id, get_tests_output_path, run_cli
6+
from tests import get_device_id, run_cli
107
from TTS.tts.configs.overflow_config import OverflowConfig
118

12-
config_path = os.path.join(get_tests_output_path(), "test_model_config.json")
13-
output_path = os.path.join(get_tests_output_path(), "train_outputs")
14-
parameter_path = os.path.join(get_tests_output_path(), "lj_parameters.pt")
159

16-
torch.save({"mean": -5.5138, "std": 2.0636, "init_transition_prob": 0.3212}, parameter_path)
10+
def test_train(tmp_path):
11+
config_path = tmp_path / "test_model_config.json"
12+
output_path = tmp_path / "train_outputs"
13+
parameter_path = tmp_path / "lj_parameters.pt"
1714

18-
config = OverflowConfig(
19-
batch_size=3,
20-
eval_batch_size=3,
21-
num_loader_workers=0,
22-
num_eval_loader_workers=0,
23-
text_cleaner="phoneme_cleaners",
24-
use_phonemes=True,
25-
phoneme_language="en-us",
26-
phoneme_cache_path=os.path.join(get_tests_output_path(), "train_outputs/phoneme_cache/"),
27-
run_eval=True,
28-
test_delay_epochs=-1,
29-
mel_statistics_parameter_path=parameter_path,
30-
epochs=1,
31-
print_step=1,
32-
test_sentences=[
33-
"Be a voice, not an echo.",
34-
],
35-
print_eval=True,
36-
max_sampling_time=50,
37-
)
38-
config.audio.do_trim_silence = True
39-
config.audio.trim_db = 60
40-
config.save_json(config_path)
15+
torch.save({"mean": -5.5138, "std": 2.0636, "init_transition_prob": 0.3212}, parameter_path)
4116

17+
config = OverflowConfig(
18+
batch_size=3,
19+
eval_batch_size=3,
20+
num_loader_workers=0,
21+
num_eval_loader_workers=0,
22+
text_cleaner="phoneme_cleaners",
23+
use_phonemes=True,
24+
phoneme_language="en-us",
25+
phoneme_cache_path=output_path / "phoneme_cache",
26+
run_eval=True,
27+
test_delay_epochs=-1,
28+
mel_statistics_parameter_path=parameter_path,
29+
epochs=1,
30+
print_step=1,
31+
test_sentences=[
32+
"Be a voice, not an echo.",
33+
],
34+
print_eval=True,
35+
max_sampling_time=50,
36+
)
37+
config.audio.do_trim_silence = True
38+
config.audio.trim_db = 60
39+
config.save_json(config_path)
4240

43-
# train the model for one epoch when mel parameters exists
44-
command_train = (
45-
f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --config_path {config_path} "
46-
f"--coqpit.output_path {output_path} "
47-
"--coqpit.datasets.0.formatter ljspeech "
48-
"--coqpit.datasets.0.meta_file_train metadata.csv "
49-
"--coqpit.datasets.0.meta_file_val metadata.csv "
50-
"--coqpit.datasets.0.path tests/data/ljspeech "
51-
"--coqpit.test_delay_epochs 0 "
52-
)
53-
run_cli(command_train)
41+
# train the model for one epoch when mel parameters exists
42+
command_train = (
43+
f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --config_path {config_path} "
44+
f"--coqpit.output_path {output_path} "
45+
"--coqpit.datasets.0.formatter ljspeech "
46+
"--coqpit.datasets.0.meta_file_train metadata.csv "
47+
"--coqpit.datasets.0.meta_file_val metadata.csv "
48+
"--coqpit.datasets.0.path tests/data/ljspeech "
49+
"--coqpit.test_delay_epochs 0 "
50+
)
51+
run_cli(command_train)
5452

53+
# train the model for one epoch when mel parameters have to be computed from the dataset
54+
if parameter_path.is_file():
55+
parameter_path.unlink()
56+
command_train = (
57+
f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --config_path {config_path} "
58+
f"--coqpit.output_path {output_path} "
59+
"--coqpit.datasets.0.formatter ljspeech "
60+
"--coqpit.datasets.0.meta_file_train metadata.csv "
61+
"--coqpit.datasets.0.meta_file_val metadata.csv "
62+
"--coqpit.datasets.0.path tests/data/ljspeech "
63+
"--coqpit.test_delay_epochs 0 "
64+
)
65+
run_cli(command_train)
5566

56-
# train the model for one epoch when mel parameters have to be computed from the dataset
57-
if os.path.exists(parameter_path):
58-
os.remove(parameter_path)
59-
command_train = (
60-
f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --config_path {config_path} "
61-
f"--coqpit.output_path {output_path} "
62-
"--coqpit.datasets.0.formatter ljspeech "
63-
"--coqpit.datasets.0.meta_file_train metadata.csv "
64-
"--coqpit.datasets.0.meta_file_val metadata.csv "
65-
"--coqpit.datasets.0.path tests/data/ljspeech "
66-
"--coqpit.test_delay_epochs 0 "
67-
)
68-
run_cli(command_train)
67+
# Find latest folder
68+
continue_path = max(output_path.iterdir(), key=lambda p: p.stat().st_mtime)
6969

70-
# Find latest folder
71-
continue_path = max(glob.glob(os.path.join(output_path, "*/")), key=os.path.getmtime)
70+
# Inference using TTS API
71+
continue_config_path = continue_path / "config.json"
72+
continue_restore_path, _ = get_last_checkpoint(continue_path)
73+
out_wav_path = tmp_path / "output.wav"
7274

73-
# Inference using TTS API
74-
continue_config_path = os.path.join(continue_path, "config.json")
75-
continue_restore_path, _ = get_last_checkpoint(continue_path)
76-
out_wav_path = os.path.join(get_tests_output_path(), "output.wav")
75+
# Check integrity of the config
76+
with continue_config_path.open() as f:
77+
config_loaded = json.load(f)
78+
assert config_loaded["characters"] is not None
79+
assert config_loaded["output_path"] in str(continue_path)
80+
assert config_loaded["test_delay_epochs"] == 0
7781

78-
# Check integrity of the config
79-
with open(continue_config_path, "r", encoding="utf-8") as f:
80-
config_loaded = json.load(f)
81-
assert config_loaded["characters"] is not None
82-
assert config_loaded["output_path"] in continue_path
83-
assert config_loaded["test_delay_epochs"] == 0
82+
# Load the model and run inference
83+
inference_command = f"CUDA_VISIBLE_DEVICES='{get_device_id()}' tts --text 'This is an example.' --config_path {continue_config_path} --model_path {continue_restore_path} --out_path {out_wav_path}"
84+
run_cli(inference_command)
8485

85-
# Load the model and run inference
86-
inference_command = f"CUDA_VISIBLE_DEVICES='{get_device_id()}' tts --text 'This is an example.' --config_path {continue_config_path} --model_path {continue_restore_path} --out_path {out_wav_path}"
87-
run_cli(inference_command)
88-
89-
# restore the model and continue training for one more epoch
90-
command_train = f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --continue_path {continue_path} "
91-
run_cli(command_train)
92-
shutil.rmtree(continue_path)
86+
# restore the model and continue training for one more epoch
87+
command_train = (
88+
f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --continue_path {continue_path} "
89+
)
90+
run_cli(command_train)

0 commit comments

Comments
 (0)