Skip to content

Commit f01b6f3

Browse files
committed
test(compute_statistics): convert to python
1 parent c66fab4 commit f01b6f3

File tree

4 files changed

+22
-13
lines changed

4 files changed

+22
-13
lines changed

TTS/bin/compute_statistics.py

+9-5
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import glob
66
import logging
77
import os
8+
from typing import Optional
89

910
import numpy as np
1011
from tqdm import tqdm
@@ -16,10 +17,7 @@
1617
from TTS.utils.generic_utils import ConsoleFormatter, setup_logger
1718

1819

19-
def main():
20-
"""Run preprocessing process."""
21-
setup_logger("TTS", level=logging.INFO, screen=True, formatter=ConsoleFormatter())
22-
20+
def parse_args(arg_list: Optional[list[str]]) -> tuple[argparse.Namespace, list[str]]:
2321
parser = argparse.ArgumentParser(description="Compute mean and variance of spectrogtram features.")
2422
parser.add_argument("config_path", type=str, help="TTS config file path to define audio processin parameters.")
2523
parser.add_argument("out_path", type=str, help="save path (directory and filename).")
@@ -29,7 +27,13 @@ def main():
2927
required=False,
3028
help="folder including the target set of wavs overriding dataset config.",
3129
)
32-
args, overrides = parser.parse_known_args()
30+
return parser.parse_known_args(arg_list)
31+
32+
33+
def main(arg_list: Optional[list[str]] = None):
34+
"""Run preprocessing process."""
35+
setup_logger("TTS", level=logging.INFO, screen=True, formatter=ConsoleFormatter())
36+
args, overrides = parse_args(arg_list)
3337

3438
CONFIG = load_config(args.config_path)
3539
CONFIG.parse_known_args(overrides, relaxed_parser=True)

run_bash_tests.sh

+1-2
Original file line numberDiff line numberDiff line change
@@ -3,5 +3,4 @@ TF_CPP_MIN_LOG_LEVEL=3
33

44
# runtime bash based tests
55
# TODO: move these to python
6-
./tests/bash_tests/test_demo_server.sh && \
7-
./tests/bash_tests/test_compute_statistics.sh
6+
./tests/bash_tests/test_demo_server.sh
+12
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
from pathlib import Path
2+
3+
from tests import get_tests_input_path
4+
from TTS.bin.compute_statistics import main
5+
6+
7+
def test_compute_statistics(tmp_path):
8+
config_path = Path(get_tests_input_path()) / "test_glow_tts_config.json"
9+
output_path = tmp_path / "scale_stats.npy"
10+
11+
args = ["--config_path", str(config_path), "--out_path", str(output_path)]
12+
main(args)

tests/bash_tests/test_compute_statistics.sh

-6
This file was deleted.

0 commit comments

Comments
 (0)