From 7fe5e5d0d1dfae4f63965a35275efacf8eb54b60 Mon Sep 17 00:00:00 2001 From: Enno Hermann Date: Sat, 14 Dec 2024 19:47:00 +0100 Subject: [PATCH] test: call cli tests via main functions to get test coverage --- TTS/bin/synthesize.py | 26 +++++++++++++----------- tests/__init__.py | 8 ++++++++ tests/inference_tests/test_synthesize.py | 23 ++++++++++----------- 3 files changed, 33 insertions(+), 24 deletions(-) diff --git a/TTS/bin/synthesize.py b/TTS/bin/synthesize.py index 47b442e266..3ba5aec948 100755 --- a/TTS/bin/synthesize.py +++ b/TTS/bin/synthesize.py @@ -7,6 +7,7 @@ import logging import sys from argparse import RawTextHelpFormatter +from typing import Optional # pylint: disable=redefined-outer-name, unused-argument from TTS.utils.generic_utils import ConsoleFormatter, setup_logger @@ -134,7 +135,7 @@ """ -def parse_args() -> argparse.Namespace: +def parse_args(arg_list: Optional[list[str]]) -> argparse.Namespace: """Parse arguments.""" parser = argparse.ArgumentParser( description=description.replace(" ```\n", ""), @@ -290,7 +291,7 @@ def parse_args() -> argparse.Namespace: help="Voice dir for tortoise model", ) - args = parser.parse_args() + args = parser.parse_args(arg_list) # print the description if either text or list_models is not set check_args = [ @@ -309,10 +310,10 @@ def parse_args() -> argparse.Namespace: return args -def main() -> None: +def main(arg_list: Optional[list[str]] = None) -> None: """Entry point for `tts` command line interface.""" setup_logger("TTS", level=logging.INFO, screen=True, formatter=ConsoleFormatter()) - args = parse_args() + args = parse_args(arg_list) pipe_out = sys.stdout if args.pipe_out else None @@ -339,18 +340,18 @@ def main() -> None: # 1) List pre-trained TTS models if args.list_models: manager.list_models() - sys.exit() + sys.exit(0) # 2) Info about pre-trained TTS models (without loading a model) if args.model_info_by_idx: model_query = args.model_info_by_idx manager.model_info_by_idx(model_query) - sys.exit() + sys.exit(0) if args.model_info_by_name: model_query_full_name = args.model_info_by_name manager.model_info_by_full_name(model_query_full_name) - sys.exit() + sys.exit(0) # 3) Load a model for further info or TTS/VC device = args.device @@ -376,23 +377,23 @@ def main() -> None: if args.list_speaker_idxs: if not api.is_multi_speaker: logger.info("Model only has a single speaker.") - return + sys.exit(0) logger.info( "Available speaker ids: (Set --speaker_idx flag to one of these values to use the multi-speaker model." ) logger.info(api.speakers) - return + sys.exit(0) # query langauge ids of a multi-lingual model. if args.list_language_idxs: if not api.is_multi_lingual: logger.info("Monolingual model.") - return + sys.exit(0) logger.info( "Available language ids: (Set --language_idx flag to one of these values to use the multi-lingual model." ) logger.info(api.languages) - return + sys.exit(0) # check the arguments against a multi-speaker model. if api.is_multi_speaker and (not args.speaker_idx and not args.speaker_wav): @@ -400,7 +401,7 @@ def main() -> None: "Looks like you use a multi-speaker model. Define `--speaker_idx` to " "select the target speaker. You can list the available speakers for this model by `--list_speaker_idxs`." ) - return + sys.exit(1) # RUN THE SYNTHESIS if args.text: @@ -429,6 +430,7 @@ def main() -> None: pipe_out=pipe_out, ) logger.info("Saved VC output to %s", args.out_path) + sys.exit(0) if __name__ == "__main__": diff --git a/tests/__init__.py b/tests/__init__.py index f0a8b2f118..8108bdeb50 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -1,5 +1,7 @@ import os +from typing import Callable, Optional +import pytest from trainer.generic_utils import get_cuda from TTS.config import BaseDatasetConfig @@ -44,6 +46,12 @@ def run_cli(command): assert exit_status == 0, f" [!] command `{command}` failed." +def run_main(main_func: Callable, args: Optional[list[str]] = None, expected_code: int = 0): + with pytest.raises(SystemExit) as exc_info: + main_func(args) + assert exc_info.value.code == expected_code + + def get_test_data_config(): return BaseDatasetConfig(formatter="ljspeech", path="tests/data/ljspeech/", meta_file_train="metadata.csv") diff --git a/tests/inference_tests/test_synthesize.py b/tests/inference_tests/test_synthesize.py index c49ea5ab43..beb7df689b 100644 --- a/tests/inference_tests/test_synthesize.py +++ b/tests/inference_tests/test_synthesize.py @@ -1,18 +1,17 @@ -from tests import run_cli +from tests import run_main +from TTS.bin.synthesize import main def test_synthesize(tmp_path): """Test synthesize.py with diffent arguments.""" - output_path = tmp_path / "output.wav" - run_cli("tts --list_models") + output_path = str(tmp_path / "output.wav") + + run_main(main, ["--list_models"]) # single speaker model - run_cli(f'tts --text "This is an example." --out_path "{output_path}"') - run_cli( - "tts --model_name tts_models/en/ljspeech/glow-tts " f'--text "This is an example." --out_path "{output_path}"' - ) - run_cli( - "tts --model_name tts_models/en/ljspeech/glow-tts " - "--vocoder_name vocoder_models/en/ljspeech/multiband-melgan " - f'--text "This is an example." --out_path "{output_path}"' - ) + args = ["--text", "This is an example.", "--out_path", output_path] + run_main(main, args) + + args = [*args, "--model_name", "tts_models/en/ljspeech/glow-tts"] + run_main(main, args) + run_main(main, [*args, "--vocoder_name", "vocoder_models/en/ljspeech/multiband-melgan"])