From ce202532cfe74e2e297e4109a80a3b125f54bd49 Mon Sep 17 00:00:00 2001 From: Enno Hermann Date: Mon, 2 Dec 2024 16:54:11 +0100 Subject: [PATCH 1/2] fix(xtts): clearer error message when file given to checkpoint_dir --- TTS/tts/models/xtts.py | 22 +++++++++++++--------- 1 file changed, 13 insertions(+), 9 deletions(-) diff --git a/TTS/tts/models/xtts.py b/TTS/tts/models/xtts.py index 35de91e359..d780e2b323 100644 --- a/TTS/tts/models/xtts.py +++ b/TTS/tts/models/xtts.py @@ -2,6 +2,7 @@ import os from dataclasses import dataclass from pathlib import Path +from typing import Optional import librosa import torch @@ -10,6 +11,7 @@ from coqpit import Coqpit from trainer.io import load_fsspec +from TTS.tts.configs.xtts_config import XttsConfig from TTS.tts.layers.xtts.gpt import GPT from TTS.tts.layers.xtts.hifigan_decoder import HifiDecoder from TTS.tts.layers.xtts.stream_generator import init_stream_support @@ -719,14 +721,14 @@ def get_compatible_checkpoint_state_dict(self, model_path): def load_checkpoint( self, - config, - checkpoint_dir=None, - checkpoint_path=None, - vocab_path=None, - eval=True, - strict=True, - use_deepspeed=False, - speaker_file_path=None, + config: XttsConfig, + checkpoint_dir: Optional[str] = None, + checkpoint_path: Optional[str] = None, + vocab_path: Optional[str] = None, + eval: bool = True, + strict: bool = True, + use_deepspeed: bool = False, + speaker_file_path: Optional[str] = None, ): """ Loads a checkpoint from disk and initializes the model's state and tokenizer. @@ -742,7 +744,9 @@ def load_checkpoint( Returns: None """ - + if checkpoint_dir is not None and Path(checkpoint_dir).is_file(): + msg = f"You passed a file to `checkpoint_dir=`. Use `checkpoint_path={checkpoint_dir}` instead." + raise ValueError(msg) model_path = checkpoint_path or os.path.join(checkpoint_dir, "model.pth") if vocab_path is None: if checkpoint_dir is not None and (Path(checkpoint_dir) / "vocab.json").is_file(): From fe14ca6b68f8757f581ec04d2d0becddd7031d05 Mon Sep 17 00:00:00 2001 From: Enno Hermann Date: Thu, 5 Dec 2024 15:38:50 +0100 Subject: [PATCH 2/2] refactor(xtts): remove duplicate xtts audio config --- TTS/demos/xtts_ft_demo/utils/gpt_train.py | 3 ++- TTS/tts/layers/xtts/trainer/gpt_trainer.py | 7 +------ TTS/tts/models/xtts.py | 5 +++-- recipes/ljspeech/xtts_v1/train_gpt_xtts.py | 3 ++- recipes/ljspeech/xtts_v2/train_gpt_xtts.py | 3 ++- tests/xtts_tests/test_xtts_gpt_train.py | 3 ++- tests/xtts_tests/test_xtts_v2-0_gpt_train.py | 3 ++- 7 files changed, 14 insertions(+), 13 deletions(-) diff --git a/TTS/demos/xtts_ft_demo/utils/gpt_train.py b/TTS/demos/xtts_ft_demo/utils/gpt_train.py index f838297af3..411a9b0dbe 100644 --- a/TTS/demos/xtts_ft_demo/utils/gpt_train.py +++ b/TTS/demos/xtts_ft_demo/utils/gpt_train.py @@ -5,7 +5,8 @@ from TTS.config.shared_configs import BaseDatasetConfig from TTS.tts.datasets import load_tts_samples -from TTS.tts.layers.xtts.trainer.gpt_trainer import GPTArgs, GPTTrainer, GPTTrainerConfig, XttsAudioConfig +from TTS.tts.layers.xtts.trainer.gpt_trainer import GPTArgs, GPTTrainer, GPTTrainerConfig +from TTS.tts.models.xtts import XttsAudioConfig from TTS.utils.manage import ModelManager diff --git a/TTS/tts/layers/xtts/trainer/gpt_trainer.py b/TTS/tts/layers/xtts/trainer/gpt_trainer.py index 0253d65ddd..107054189c 100644 --- a/TTS/tts/layers/xtts/trainer/gpt_trainer.py +++ b/TTS/tts/layers/xtts/trainer/gpt_trainer.py @@ -18,7 +18,7 @@ from TTS.tts.layers.xtts.tokenizer import VoiceBpeTokenizer from TTS.tts.layers.xtts.trainer.dataset import XTTSDataset from TTS.tts.models.base_tts import BaseTTS -from TTS.tts.models.xtts import Xtts, XttsArgs, XttsAudioConfig +from TTS.tts.models.xtts import Xtts, XttsArgs from TTS.utils.generic_utils import is_pytorch_at_least_2_4 logger = logging.getLogger(__name__) @@ -34,11 +34,6 @@ class GPTTrainerConfig(XttsConfig): test_sentences: List[dict] = field(default_factory=lambda: []) -@dataclass -class XttsAudioConfig(XttsAudioConfig): - dvae_sample_rate: int = 22050 - - @dataclass class GPTArgs(XttsArgs): min_conditioning_length: int = 66150 diff --git a/TTS/tts/models/xtts.py b/TTS/tts/models/xtts.py index d780e2b323..f05863ae1d 100644 --- a/TTS/tts/models/xtts.py +++ b/TTS/tts/models/xtts.py @@ -11,7 +11,6 @@ from coqpit import Coqpit from trainer.io import load_fsspec -from TTS.tts.configs.xtts_config import XttsConfig from TTS.tts.layers.xtts.gpt import GPT from TTS.tts.layers.xtts.hifigan_decoder import HifiDecoder from TTS.tts.layers.xtts.stream_generator import init_stream_support @@ -103,10 +102,12 @@ class XttsAudioConfig(Coqpit): Args: sample_rate (int): The sample rate in which the GPT operates. output_sample_rate (int): The sample rate of the output audio waveform. + dvae_sample_rate (int): The sample rate of the DVAE """ sample_rate: int = 22050 output_sample_rate: int = 24000 + dvae_sample_rate: int = 22050 @dataclass @@ -721,7 +722,7 @@ def get_compatible_checkpoint_state_dict(self, model_path): def load_checkpoint( self, - config: XttsConfig, + config: "XttsConfig", checkpoint_dir: Optional[str] = None, checkpoint_path: Optional[str] = None, vocab_path: Optional[str] = None, diff --git a/recipes/ljspeech/xtts_v1/train_gpt_xtts.py b/recipes/ljspeech/xtts_v1/train_gpt_xtts.py index d31ec8f1ed..a077a18064 100644 --- a/recipes/ljspeech/xtts_v1/train_gpt_xtts.py +++ b/recipes/ljspeech/xtts_v1/train_gpt_xtts.py @@ -4,7 +4,8 @@ from TTS.config.shared_configs import BaseDatasetConfig from TTS.tts.datasets import load_tts_samples -from TTS.tts.layers.xtts.trainer.gpt_trainer import GPTArgs, GPTTrainer, GPTTrainerConfig, XttsAudioConfig +from TTS.tts.layers.xtts.trainer.gpt_trainer import GPTArgs, GPTTrainer, GPTTrainerConfig +from TTS.tts.models.xtts import XttsAudioConfig from TTS.utils.manage import ModelManager # Logging parameters diff --git a/recipes/ljspeech/xtts_v2/train_gpt_xtts.py b/recipes/ljspeech/xtts_v2/train_gpt_xtts.py index ccaa97f1e4..362f45008e 100644 --- a/recipes/ljspeech/xtts_v2/train_gpt_xtts.py +++ b/recipes/ljspeech/xtts_v2/train_gpt_xtts.py @@ -4,7 +4,8 @@ from TTS.config.shared_configs import BaseDatasetConfig from TTS.tts.datasets import load_tts_samples -from TTS.tts.layers.xtts.trainer.gpt_trainer import GPTArgs, GPTTrainer, GPTTrainerConfig, XttsAudioConfig +from TTS.tts.layers.xtts.trainer.gpt_trainer import GPTArgs, GPTTrainer, GPTTrainerConfig +from TTS.tts.models.xtts import XttsAudioConfig from TTS.utils.manage import ModelManager # Logging parameters diff --git a/tests/xtts_tests/test_xtts_gpt_train.py b/tests/xtts_tests/test_xtts_gpt_train.py index b8b9a4e388..bb592f1f2d 100644 --- a/tests/xtts_tests/test_xtts_gpt_train.py +++ b/tests/xtts_tests/test_xtts_gpt_train.py @@ -8,7 +8,8 @@ from TTS.config.shared_configs import BaseDatasetConfig from TTS.tts.datasets import load_tts_samples from TTS.tts.layers.xtts.dvae import DiscreteVAE -from TTS.tts.layers.xtts.trainer.gpt_trainer import GPTArgs, GPTTrainer, GPTTrainerConfig, XttsAudioConfig +from TTS.tts.layers.xtts.trainer.gpt_trainer import GPTArgs, GPTTrainer, GPTTrainerConfig +from TTS.tts.models.xtts import XttsAudioConfig config_dataset = BaseDatasetConfig( formatter="ljspeech", diff --git a/tests/xtts_tests/test_xtts_v2-0_gpt_train.py b/tests/xtts_tests/test_xtts_v2-0_gpt_train.py index 6663433c12..454e867385 100644 --- a/tests/xtts_tests/test_xtts_v2-0_gpt_train.py +++ b/tests/xtts_tests/test_xtts_v2-0_gpt_train.py @@ -8,7 +8,8 @@ from TTS.config.shared_configs import BaseDatasetConfig from TTS.tts.datasets import load_tts_samples from TTS.tts.layers.xtts.dvae import DiscreteVAE -from TTS.tts.layers.xtts.trainer.gpt_trainer import GPTArgs, GPTTrainer, GPTTrainerConfig, XttsAudioConfig +from TTS.tts.layers.xtts.trainer.gpt_trainer import GPTArgs, GPTTrainer, GPTTrainerConfig +from TTS.tts.models.xtts import XttsAudioConfig config_dataset = BaseDatasetConfig( formatter="ljspeech",