diff --git a/TTS/__init__.py b/TTS/__init__.py index 8e93c9b5db..7f2159225c 100644 --- a/TTS/__init__.py +++ b/TTS/__init__.py @@ -13,8 +13,7 @@ import torch from TTS.config.shared_configs import BaseDatasetConfig - from TTS.tts.configs.xtts_config import XttsConfig - from TTS.tts.models.xtts import XttsArgs, XttsAudioConfig + from TTS.tts.configs.xtts_config import XttsArgs, XttsAudioConfig, XttsConfig from TTS.utils.radam import RAdam torch.serialization.add_safe_globals([dict, defaultdict, RAdam]) diff --git a/TTS/demos/xtts_ft_demo/utils/gpt_train.py b/TTS/demos/xtts_ft_demo/utils/gpt_train.py index f838297af3..f4c390b128 100644 --- a/TTS/demos/xtts_ft_demo/utils/gpt_train.py +++ b/TTS/demos/xtts_ft_demo/utils/gpt_train.py @@ -4,8 +4,9 @@ from trainer import Trainer, TrainerArgs from TTS.config.shared_configs import BaseDatasetConfig +from TTS.tts.configs.xtts_config import XttsAudioConfig from TTS.tts.datasets import load_tts_samples -from TTS.tts.layers.xtts.trainer.gpt_trainer import GPTArgs, GPTTrainer, GPTTrainerConfig, XttsAudioConfig +from TTS.tts.layers.xtts.trainer.gpt_trainer import GPTArgs, GPTTrainer, GPTTrainerConfig from TTS.utils.manage import ModelManager diff --git a/TTS/tts/configs/xtts_config.py b/TTS/tts/configs/xtts_config.py index bbf048e1ab..7081d63742 100644 --- a/TTS/tts/configs/xtts_config.py +++ b/TTS/tts/configs/xtts_config.py @@ -1,8 +1,92 @@ from dataclasses import dataclass, field from typing import List +from coqpit import Coqpit + from TTS.tts.configs.shared_configs import BaseTTSConfig -from TTS.tts.models.xtts import XttsArgs, XttsAudioConfig + + +@dataclass +class XttsAudioConfig(Coqpit): + """ + Configuration class for audio-related parameters in the XTTS model. + + Args: + sample_rate (int): The sample rate in which the GPT operates. + output_sample_rate (int): The sample rate of the output audio waveform. 
+ dvae_sample_rate (int): The sample rate of the DVAE. + """ + + sample_rate: int = 22050 + output_sample_rate: int = 24000 + dvae_sample_rate: int = 22050 + + +@dataclass +class XttsArgs(Coqpit): + """A dataclass to represent XTTS model arguments that define the model structure. + + Args: + gpt_batch_size (int): The size of the auto-regressive batch. + enable_redaction (bool, optional): Whether to enable redaction. Defaults to False. + kv_cache (bool, optional): Whether to use the kv_cache. Defaults to True. + gpt_checkpoint (str, optional): The checkpoint for the autoregressive model. Defaults to None. + clvp_checkpoint (str, optional): The checkpoint for the ConditionalLatentVariablePerseq model. Defaults to None. + decoder_checkpoint (str, optional): The checkpoint for the DiffTTS model. Defaults to None. + num_chars (int, optional): The maximum number of characters to generate. Defaults to 255. + + For GPT model: + gpt_max_audio_tokens (int, optional): The maximum mel tokens for the autoregressive model. Defaults to 605. + gpt_max_text_tokens (int, optional): The maximum text tokens for the autoregressive model. Defaults to 402. + gpt_max_prompt_tokens (int, optional): The maximum prompt tokens for the autoregressive model. Defaults to 70. + gpt_layers (int, optional): The number of layers for the autoregressive model. Defaults to 30. + gpt_n_model_channels (int, optional): The model dimension for the autoregressive model. Defaults to 1024. + gpt_n_heads (int, optional): The number of heads for the autoregressive model. Defaults to 16. + gpt_number_text_tokens (int, optional): The number of text tokens for the autoregressive model. Defaults to None. + gpt_start_text_token (int, optional): The start text token for the autoregressive model. Defaults to None. + gpt_checkpointing (bool, optional): Whether to use checkpointing for the autoregressive model. Defaults to False. 
+ gpt_train_solo_embeddings (bool, optional): Whether to train embeddings for the autoregressive model. Defaults to False. + gpt_code_stride_len (int, optional): The hop_size of dvae and consequently of the gpt output. Defaults to 1024. + gpt_use_masking_gt_prompt_approach (bool, optional): If True, it will use ground truth as prompt and it will mask the loss to avoid repetition. Defaults to True. + gpt_use_perceiver_resampler (bool, optional): If True, it will use perceiver resampler from flamingo paper - https://arxiv.org/abs/2204.14198. Defaults to False. + """ + + gpt_batch_size: int = 1 + enable_redaction: bool = False + kv_cache: bool = True + gpt_checkpoint: str = None + clvp_checkpoint: str = None + decoder_checkpoint: str = None + num_chars: int = 255 + + # XTTS GPT Encoder params + tokenizer_file: str = "" + gpt_max_audio_tokens: int = 605 + gpt_max_text_tokens: int = 402 + gpt_max_prompt_tokens: int = 70 + gpt_layers: int = 30 + gpt_n_model_channels: int = 1024 + gpt_n_heads: int = 16 + gpt_number_text_tokens: int = None + gpt_start_text_token: int = None + gpt_stop_text_token: int = None + gpt_num_audio_tokens: int = 8194 + gpt_start_audio_token: int = 8192 + gpt_stop_audio_token: int = 8193 + gpt_code_stride_len: int = 1024 + gpt_use_masking_gt_prompt_approach: bool = True + gpt_use_perceiver_resampler: bool = False + + # HifiGAN Decoder params + input_sample_rate: int = 22050 + output_sample_rate: int = 24000 + output_hop_length: int = 256 + decoder_input_dim: int = 1024 + d_vector_dim: int = 512 + cond_d_vector_in_each_upsampling_layer: bool = True + + # constants + duration_const: int = 102400 @dataclass diff --git a/TTS/tts/layers/xtts/trainer/gpt_trainer.py b/TTS/tts/layers/xtts/trainer/gpt_trainer.py index 0253d65ddd..8b6bb09f5e 100644 --- a/TTS/tts/layers/xtts/trainer/gpt_trainer.py +++ b/TTS/tts/layers/xtts/trainer/gpt_trainer.py @@ -11,14 +11,14 @@ from trainer.torch import DistributedSampler from trainer.trainer_utils import get_optimizer, 
get_scheduler -from TTS.tts.configs.xtts_config import XttsConfig +from TTS.tts.configs.xtts_config import XttsArgs, XttsConfig from TTS.tts.datasets.dataset import TTSDataset from TTS.tts.layers.tortoise.arch_utils import TorchMelSpectrogram from TTS.tts.layers.xtts.dvae import DiscreteVAE from TTS.tts.layers.xtts.tokenizer import VoiceBpeTokenizer from TTS.tts.layers.xtts.trainer.dataset import XTTSDataset from TTS.tts.models.base_tts import BaseTTS -from TTS.tts.models.xtts import Xtts, XttsArgs, XttsAudioConfig +from TTS.tts.models.xtts import Xtts from TTS.utils.generic_utils import is_pytorch_at_least_2_4 logger = logging.getLogger(__name__) @@ -34,11 +34,6 @@ class GPTTrainerConfig(XttsConfig): test_sentences: List[dict] = field(default_factory=lambda: []) -@dataclass -class XttsAudioConfig(XttsAudioConfig): - dvae_sample_rate: int = 22050 - - @dataclass class GPTArgs(XttsArgs): min_conditioning_length: int = 66150 diff --git a/TTS/tts/models/xtts.py b/TTS/tts/models/xtts.py index d780e2b323..50455d2ef6 100644 --- a/TTS/tts/models/xtts.py +++ b/TTS/tts/models/xtts.py @@ -1,6 +1,5 @@ import logging import os -from dataclasses import dataclass from pathlib import Path from typing import Optional @@ -95,87 +94,6 @@ def load_audio(audiopath, sampling_rate): return audio -@dataclass -class XttsAudioConfig(Coqpit): - """ - Configuration class for audio-related parameters in the XTTS model. - - Args: - sample_rate (int): The sample rate in which the GPT operates. - output_sample_rate (int): The sample rate of the output audio waveform. - """ - - sample_rate: int = 22050 - output_sample_rate: int = 24000 - - -@dataclass -class XttsArgs(Coqpit): - """A dataclass to represent XTTS model arguments that define the model structure. - - Args: - gpt_batch_size (int): The size of the auto-regressive batch. - enable_redaction (bool, optional): Whether to enable redaction. Defaults to True. - kv_cache (bool, optional): Whether to use the kv_cache. Defaults to True. 
- gpt_checkpoint (str, optional): The checkpoint for the autoregressive model. Defaults to None. - clvp_checkpoint (str, optional): The checkpoint for the ConditionalLatentVariablePerseq model. Defaults to None. - decoder_checkpoint (str, optional): The checkpoint for the DiffTTS model. Defaults to None. - num_chars (int, optional): The maximum number of characters to generate. Defaults to 255. - - For GPT model: - gpt_max_audio_tokens (int, optional): The maximum mel tokens for the autoregressive model. Defaults to 604. - gpt_max_text_tokens (int, optional): The maximum text tokens for the autoregressive model. Defaults to 402. - gpt_max_prompt_tokens (int, optional): The maximum prompt tokens or the autoregressive model. Defaults to 70. - gpt_layers (int, optional): The number of layers for the autoregressive model. Defaults to 30. - gpt_n_model_channels (int, optional): The model dimension for the autoregressive model. Defaults to 1024. - gpt_n_heads (int, optional): The number of heads for the autoregressive model. Defaults to 16. - gpt_number_text_tokens (int, optional): The number of text tokens for the autoregressive model. Defaults to 255. - gpt_start_text_token (int, optional): The start text token for the autoregressive model. Defaults to 255. - gpt_checkpointing (bool, optional): Whether to use checkpointing for the autoregressive model. Defaults to False. - gpt_train_solo_embeddings (bool, optional): Whether to train embeddings for the autoregressive model. Defaults to False. - gpt_code_stride_len (int, optional): The hop_size of dvae and consequently of the gpt output. Defaults to 1024. - gpt_use_masking_gt_prompt_approach (bool, optional): If True, it will use ground truth as prompt and it will mask the loss to avoid repetition. Defaults to True. - gpt_use_perceiver_resampler (bool, optional): If True, it will use perceiver resampler from flamingo paper - https://arxiv.org/abs/2204.14198. Defaults to False. 
- """ - - gpt_batch_size: int = 1 - enable_redaction: bool = False - kv_cache: bool = True - gpt_checkpoint: str = None - clvp_checkpoint: str = None - decoder_checkpoint: str = None - num_chars: int = 255 - - # XTTS GPT Encoder params - tokenizer_file: str = "" - gpt_max_audio_tokens: int = 605 - gpt_max_text_tokens: int = 402 - gpt_max_prompt_tokens: int = 70 - gpt_layers: int = 30 - gpt_n_model_channels: int = 1024 - gpt_n_heads: int = 16 - gpt_number_text_tokens: int = None - gpt_start_text_token: int = None - gpt_stop_text_token: int = None - gpt_num_audio_tokens: int = 8194 - gpt_start_audio_token: int = 8192 - gpt_stop_audio_token: int = 8193 - gpt_code_stride_len: int = 1024 - gpt_use_masking_gt_prompt_approach: bool = True - gpt_use_perceiver_resampler: bool = False - - # HifiGAN Decoder params - input_sample_rate: int = 22050 - output_sample_rate: int = 24000 - output_hop_length: int = 256 - decoder_input_dim: int = 1024 - d_vector_dim: int = 512 - cond_d_vector_in_each_upsampling_layer: bool = True - - # constants - duration_const: int = 102400 - - class Xtts(BaseTTS): """ⓍTTS model implementation. diff --git a/docs/source/models/xtts.md b/docs/source/models/xtts.md index c07d879f7c..a1309b8544 100644 --- a/docs/source/models/xtts.md +++ b/docs/source/models/xtts.md @@ -387,7 +387,7 @@ torchaudio.save(OUTPUT_WAV_PATH, torch.tensor(out["wav"]).unsqueeze(0), 24000) ## XttsArgs ```{eval-rst} -.. autoclass:: TTS.tts.models.xtts.XttsArgs +.. 
autoclass:: TTS.tts.configs.xtts_config.XttsArgs :members: ``` diff --git a/recipes/ljspeech/xtts_v1/train_gpt_xtts.py b/recipes/ljspeech/xtts_v1/train_gpt_xtts.py index d31ec8f1ed..8ecbd255f3 100644 --- a/recipes/ljspeech/xtts_v1/train_gpt_xtts.py +++ b/recipes/ljspeech/xtts_v1/train_gpt_xtts.py @@ -3,8 +3,9 @@ from trainer import Trainer, TrainerArgs from TTS.config.shared_configs import BaseDatasetConfig +from TTS.tts.configs.xtts_config import XttsAudioConfig from TTS.tts.datasets import load_tts_samples -from TTS.tts.layers.xtts.trainer.gpt_trainer import GPTArgs, GPTTrainer, GPTTrainerConfig, XttsAudioConfig +from TTS.tts.layers.xtts.trainer.gpt_trainer import GPTArgs, GPTTrainer, GPTTrainerConfig from TTS.utils.manage import ModelManager # Logging parameters diff --git a/recipes/ljspeech/xtts_v2/train_gpt_xtts.py b/recipes/ljspeech/xtts_v2/train_gpt_xtts.py index ccaa97f1e4..8c384c7961 100644 --- a/recipes/ljspeech/xtts_v2/train_gpt_xtts.py +++ b/recipes/ljspeech/xtts_v2/train_gpt_xtts.py @@ -3,8 +3,9 @@ from trainer import Trainer, TrainerArgs from TTS.config.shared_configs import BaseDatasetConfig +from TTS.tts.configs.xtts_config import XttsAudioConfig from TTS.tts.datasets import load_tts_samples -from TTS.tts.layers.xtts.trainer.gpt_trainer import GPTArgs, GPTTrainer, GPTTrainerConfig, XttsAudioConfig +from TTS.tts.layers.xtts.trainer.gpt_trainer import GPTArgs, GPTTrainer, GPTTrainerConfig from TTS.utils.manage import ModelManager # Logging parameters diff --git a/tests/xtts_tests/test_xtts_gpt_train.py b/tests/xtts_tests/test_xtts_gpt_train.py index b8b9a4e388..55cc81d587 100644 --- a/tests/xtts_tests/test_xtts_gpt_train.py +++ b/tests/xtts_tests/test_xtts_gpt_train.py @@ -6,9 +6,10 @@ from tests import get_tests_output_path from TTS.config.shared_configs import BaseDatasetConfig +from TTS.tts.configs.xtts_config import XttsAudioConfig from TTS.tts.datasets import load_tts_samples from TTS.tts.layers.xtts.dvae import DiscreteVAE -from 
TTS.tts.layers.xtts.trainer.gpt_trainer import GPTArgs, GPTTrainer, GPTTrainerConfig, XttsAudioConfig +from TTS.tts.layers.xtts.trainer.gpt_trainer import GPTArgs, GPTTrainer, GPTTrainerConfig config_dataset = BaseDatasetConfig( formatter="ljspeech", diff --git a/tests/xtts_tests/test_xtts_v2-0_gpt_train.py b/tests/xtts_tests/test_xtts_v2-0_gpt_train.py index 6663433c12..c4f07b42ab 100644 --- a/tests/xtts_tests/test_xtts_v2-0_gpt_train.py +++ b/tests/xtts_tests/test_xtts_v2-0_gpt_train.py @@ -6,9 +6,10 @@ from tests import get_tests_output_path from TTS.config.shared_configs import BaseDatasetConfig +from TTS.tts.configs.xtts_config import XttsAudioConfig from TTS.tts.datasets import load_tts_samples from TTS.tts.layers.xtts.dvae import DiscreteVAE -from TTS.tts.layers.xtts.trainer.gpt_trainer import GPTArgs, GPTTrainer, GPTTrainerConfig, XttsAudioConfig +from TTS.tts.layers.xtts.trainer.gpt_trainer import GPTArgs, GPTTrainer, GPTTrainerConfig config_dataset = BaseDatasetConfig( formatter="ljspeech",