refactor(xtts): move configs into configs module
Avoids circular imports
eginhard committed Dec 2, 2024 · 1 parent ce20253 · commit e4c44f8
Showing 10 changed files with 99 additions and 98 deletions.
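Before this change, TTS/tts/configs/xtts_config.py imported XttsArgs and XttsAudioConfig from TTS/tts/models/xtts.py, so loading the config module pulled the model module into the import graph. Moving the two dataclasses into the configs module removes that edge. A minimal, illustrative sketch of the import locations after this commit (not part of the diff itself):

```python
# Illustrative only: canonical import locations after this refactor.
# The config module no longer imports from TTS.tts.models.xtts, so importing
# the XTTS configs cannot re-enter the model module and trigger a circular import.
from TTS.tts.configs.xtts_config import XttsArgs, XttsAudioConfig, XttsConfig
from TTS.tts.models.xtts import Xtts
```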
3 changes: 1 addition & 2 deletions TTS/__init__.py
@@ -13,8 +13,7 @@
import torch

from TTS.config.shared_configs import BaseDatasetConfig
from TTS.tts.configs.xtts_config import XttsConfig
from TTS.tts.models.xtts import XttsArgs, XttsAudioConfig
from TTS.tts.configs.xtts_config import XttsArgs, XttsAudioConfig, XttsConfig
from TTS.utils.radam import RAdam

torch.serialization.add_safe_globals([dict, defaultdict, RAdam])
3 changes: 2 additions & 1 deletion TTS/demos/xtts_ft_demo/utils/gpt_train.py
@@ -4,8 +4,9 @@
from trainer import Trainer, TrainerArgs

from TTS.config.shared_configs import BaseDatasetConfig
from TTS.tts.configs.xtts_config import XttsAudioConfig
from TTS.tts.datasets import load_tts_samples
from TTS.tts.layers.xtts.trainer.gpt_trainer import GPTArgs, GPTTrainer, GPTTrainerConfig, XttsAudioConfig
from TTS.tts.layers.xtts.trainer.gpt_trainer import GPTArgs, GPTTrainer, GPTTrainerConfig
from TTS.utils.manage import ModelManager


86 changes: 85 additions & 1 deletion TTS/tts/configs/xtts_config.py
@@ -1,8 +1,92 @@
from dataclasses import dataclass, field
from typing import List

from coqpit import Coqpit

from TTS.tts.configs.shared_configs import BaseTTSConfig
from TTS.tts.models.xtts import XttsArgs, XttsAudioConfig


@dataclass
class XttsAudioConfig(Coqpit):
"""
Configuration class for audio-related parameters in the XTTS model.
Args:
sample_rate (int): The sample rate in which the GPT operates.
output_sample_rate (int): The sample rate of the output audio waveform.
dvae_sample_rate (int): The sample rate of the DVAE
"""

sample_rate: int = 22050
output_sample_rate: int = 24000
dvae_sample_rate: int = 22050


@dataclass
class XttsArgs(Coqpit):
"""A dataclass to represent XTTS model arguments that define the model structure.
Args:
gpt_batch_size (int): The size of the auto-regressive batch.
enable_redaction (bool, optional): Whether to enable redaction. Defaults to True.
kv_cache (bool, optional): Whether to use the kv_cache. Defaults to True.
gpt_checkpoint (str, optional): The checkpoint for the autoregressive model. Defaults to None.
clvp_checkpoint (str, optional): The checkpoint for the ConditionalLatentVariablePerseq model. Defaults to None.
decoder_checkpoint (str, optional): The checkpoint for the DiffTTS model. Defaults to None.
num_chars (int, optional): The maximum number of characters to generate. Defaults to 255.
For GPT model:
gpt_max_audio_tokens (int, optional): The maximum mel tokens for the autoregressive model. Defaults to 604.
gpt_max_text_tokens (int, optional): The maximum text tokens for the autoregressive model. Defaults to 402.
gpt_max_prompt_tokens (int, optional): The maximum prompt tokens or the autoregressive model. Defaults to 70.
gpt_layers (int, optional): The number of layers for the autoregressive model. Defaults to 30.
gpt_n_model_channels (int, optional): The model dimension for the autoregressive model. Defaults to 1024.
gpt_n_heads (int, optional): The number of heads for the autoregressive model. Defaults to 16.
gpt_number_text_tokens (int, optional): The number of text tokens for the autoregressive model. Defaults to 255.
gpt_start_text_token (int, optional): The start text token for the autoregressive model. Defaults to 255.
gpt_checkpointing (bool, optional): Whether to use checkpointing for the autoregressive model. Defaults to False.
gpt_train_solo_embeddings (bool, optional): Whether to train embeddings for the autoregressive model. Defaults to False.
gpt_code_stride_len (int, optional): The hop_size of dvae and consequently of the gpt output. Defaults to 1024.
gpt_use_masking_gt_prompt_approach (bool, optional): If True, it will use ground truth as prompt and it will mask the loss to avoid repetition. Defaults to True.
gpt_use_perceiver_resampler (bool, optional): If True, it will use perceiver resampler from flamingo paper - https://arxiv.org/abs/2204.14198. Defaults to False.
"""

gpt_batch_size: int = 1
enable_redaction: bool = False
kv_cache: bool = True
gpt_checkpoint: str = None
clvp_checkpoint: str = None
decoder_checkpoint: str = None
num_chars: int = 255

# XTTS GPT Encoder params
tokenizer_file: str = ""
gpt_max_audio_tokens: int = 605
gpt_max_text_tokens: int = 402
gpt_max_prompt_tokens: int = 70
gpt_layers: int = 30
gpt_n_model_channels: int = 1024
gpt_n_heads: int = 16
gpt_number_text_tokens: int = None
gpt_start_text_token: int = None
gpt_stop_text_token: int = None
gpt_num_audio_tokens: int = 8194
gpt_start_audio_token: int = 8192
gpt_stop_audio_token: int = 8193
gpt_code_stride_len: int = 1024
gpt_use_masking_gt_prompt_approach: bool = True
gpt_use_perceiver_resampler: bool = False

# HifiGAN Decoder params
input_sample_rate: int = 22050
output_sample_rate: int = 24000
output_hop_length: int = 256
decoder_input_dim: int = 1024
d_vector_dim: int = 512
cond_d_vector_in_each_upsampling_layer: bool = True

# constants
duration_const: int = 102400


@dataclass
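With XttsAudioConfig and XttsArgs now defined next to XttsConfig, a full XTTS configuration can be assembled from a single module. A hedged usage sketch follows; it assumes XttsConfig keeps its usual model_args and audio fields (not shown in this hunk), and the values are the dataclass defaults listed above:

```python
from TTS.tts.configs.xtts_config import XttsArgs, XttsAudioConfig, XttsConfig

# The GPT consumes 22.05 kHz audio, the vocoder emits 24 kHz, and the DVAE runs
# at 22.05 kHz (dvae_sample_rate is the field that previously lived only on the
# trainer-side subclass).
audio = XttsAudioConfig(sample_rate=22050, output_sample_rate=24000, dvae_sample_rate=22050)

# Model-structure arguments; values shown are the dataclass defaults.
model_args = XttsArgs(gpt_layers=30, gpt_n_model_channels=1024, gpt_n_heads=16)

# Assumption: XttsConfig exposes model_args and audio fields as in upstream coqui-ai/TTS.
config = XttsConfig(model_args=model_args, audio=audio)
```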
9 changes: 2 additions & 7 deletions TTS/tts/layers/xtts/trainer/gpt_trainer.py
@@ -11,14 +11,14 @@
from trainer.torch import DistributedSampler
from trainer.trainer_utils import get_optimizer, get_scheduler

from TTS.tts.configs.xtts_config import XttsConfig
from TTS.tts.configs.xtts_config import XttsArgs, XttsConfig
from TTS.tts.datasets.dataset import TTSDataset
from TTS.tts.layers.tortoise.arch_utils import TorchMelSpectrogram
from TTS.tts.layers.xtts.dvae import DiscreteVAE
from TTS.tts.layers.xtts.tokenizer import VoiceBpeTokenizer
from TTS.tts.layers.xtts.trainer.dataset import XTTSDataset
from TTS.tts.models.base_tts import BaseTTS
from TTS.tts.models.xtts import Xtts, XttsArgs, XttsAudioConfig
from TTS.tts.models.xtts import Xtts
from TTS.utils.generic_utils import is_pytorch_at_least_2_4

logger = logging.getLogger(__name__)
@@ -34,11 +34,6 @@ class GPTTrainerConfig(XttsConfig):
test_sentences: List[dict] = field(default_factory=lambda: [])


@dataclass
class XttsAudioConfig(XttsAudioConfig):
dvae_sample_rate: int = 22050


@dataclass
class GPTArgs(XttsArgs):
min_conditioning_length: int = 66150
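The trainer module previously shadowed XttsAudioConfig with a local subclass whose only addition was dvae_sample_rate; that field now lives on the shared dataclass, so the subclass is dropped and training code imports the single definition from the configs module. A hedged sketch of the resulting call site; the keyword names are assumptions based on how the upstream recipes construct these configs, not shown in this diff:

```python
from TTS.tts.configs.xtts_config import XttsAudioConfig
from TTS.tts.layers.xtts.trainer.gpt_trainer import GPTArgs, GPTTrainerConfig

# One audio dataclass now serves both the GPT trainer and the DVAE; the former
# trainer-local XttsAudioConfig subclass is no longer needed.
audio_config = XttsAudioConfig(sample_rate=22050, dvae_sample_rate=22050, output_sample_rate=24000)
config = GPTTrainerConfig(audio=audio_config, model_args=GPTArgs(gpt_batch_size=1))
```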
82 changes: 0 additions & 82 deletions TTS/tts/models/xtts.py
@@ -1,6 +1,5 @@
import logging
import os
from dataclasses import dataclass
from pathlib import Path
from typing import Optional

@@ -95,87 +94,6 @@ def load_audio(audiopath, sampling_rate):
return audio


@dataclass
class XttsAudioConfig(Coqpit):
"""
Configuration class for audio-related parameters in the XTTS model.
Args:
sample_rate (int): The sample rate in which the GPT operates.
output_sample_rate (int): The sample rate of the output audio waveform.
"""

sample_rate: int = 22050
output_sample_rate: int = 24000


@dataclass
class XttsArgs(Coqpit):
"""A dataclass to represent XTTS model arguments that define the model structure.
Args:
gpt_batch_size (int): The size of the auto-regressive batch.
enable_redaction (bool, optional): Whether to enable redaction. Defaults to True.
kv_cache (bool, optional): Whether to use the kv_cache. Defaults to True.
gpt_checkpoint (str, optional): The checkpoint for the autoregressive model. Defaults to None.
clvp_checkpoint (str, optional): The checkpoint for the ConditionalLatentVariablePerseq model. Defaults to None.
decoder_checkpoint (str, optional): The checkpoint for the DiffTTS model. Defaults to None.
num_chars (int, optional): The maximum number of characters to generate. Defaults to 255.
For GPT model:
gpt_max_audio_tokens (int, optional): The maximum mel tokens for the autoregressive model. Defaults to 604.
gpt_max_text_tokens (int, optional): The maximum text tokens for the autoregressive model. Defaults to 402.
gpt_max_prompt_tokens (int, optional): The maximum prompt tokens or the autoregressive model. Defaults to 70.
gpt_layers (int, optional): The number of layers for the autoregressive model. Defaults to 30.
gpt_n_model_channels (int, optional): The model dimension for the autoregressive model. Defaults to 1024.
gpt_n_heads (int, optional): The number of heads for the autoregressive model. Defaults to 16.
gpt_number_text_tokens (int, optional): The number of text tokens for the autoregressive model. Defaults to 255.
gpt_start_text_token (int, optional): The start text token for the autoregressive model. Defaults to 255.
gpt_checkpointing (bool, optional): Whether to use checkpointing for the autoregressive model. Defaults to False.
gpt_train_solo_embeddings (bool, optional): Whether to train embeddings for the autoregressive model. Defaults to False.
gpt_code_stride_len (int, optional): The hop_size of dvae and consequently of the gpt output. Defaults to 1024.
gpt_use_masking_gt_prompt_approach (bool, optional): If True, it will use ground truth as prompt and it will mask the loss to avoid repetition. Defaults to True.
gpt_use_perceiver_resampler (bool, optional): If True, it will use perceiver resampler from flamingo paper - https://arxiv.org/abs/2204.14198. Defaults to False.
"""

gpt_batch_size: int = 1
enable_redaction: bool = False
kv_cache: bool = True
gpt_checkpoint: str = None
clvp_checkpoint: str = None
decoder_checkpoint: str = None
num_chars: int = 255

# XTTS GPT Encoder params
tokenizer_file: str = ""
gpt_max_audio_tokens: int = 605
gpt_max_text_tokens: int = 402
gpt_max_prompt_tokens: int = 70
gpt_layers: int = 30
gpt_n_model_channels: int = 1024
gpt_n_heads: int = 16
gpt_number_text_tokens: int = None
gpt_start_text_token: int = None
gpt_stop_text_token: int = None
gpt_num_audio_tokens: int = 8194
gpt_start_audio_token: int = 8192
gpt_stop_audio_token: int = 8193
gpt_code_stride_len: int = 1024
gpt_use_masking_gt_prompt_approach: bool = True
gpt_use_perceiver_resampler: bool = False

# HifiGAN Decoder params
input_sample_rate: int = 22050
output_sample_rate: int = 24000
output_hop_length: int = 256
decoder_input_dim: int = 1024
d_vector_dim: int = 512
cond_d_vector_in_each_upsampling_layer: bool = True

# constants
duration_const: int = 102400


class Xtts(BaseTTS):
"""ⓍTTS model implementation.
2 changes: 1 addition & 1 deletion docs/source/models/xtts.md
@@ -387,7 +387,7 @@ torchaudio.save(OUTPUT_WAV_PATH, torch.tensor(out["wav"]).unsqueeze(0), 24000)

## XttsArgs
```{eval-rst}
.. autoclass:: TTS.tts.models.xtts.XttsArgs
.. autoclass:: TTS.tts.configs.xtts_config.XttsArgs
:members:
```

3 changes: 2 additions & 1 deletion recipes/ljspeech/xtts_v1/train_gpt_xtts.py
@@ -3,8 +3,9 @@
from trainer import Trainer, TrainerArgs

from TTS.config.shared_configs import BaseDatasetConfig
from TTS.tts.configs.xtts_config import XttsAudioConfig
from TTS.tts.datasets import load_tts_samples
from TTS.tts.layers.xtts.trainer.gpt_trainer import GPTArgs, GPTTrainer, GPTTrainerConfig, XttsAudioConfig
from TTS.tts.layers.xtts.trainer.gpt_trainer import GPTArgs, GPTTrainer, GPTTrainerConfig
from TTS.utils.manage import ModelManager

# Logging parameters
3 changes: 2 additions & 1 deletion recipes/ljspeech/xtts_v2/train_gpt_xtts.py
@@ -3,8 +3,9 @@
from trainer import Trainer, TrainerArgs

from TTS.config.shared_configs import BaseDatasetConfig
from TTS.tts.configs.xtts_config import XttsAudioConfig
from TTS.tts.datasets import load_tts_samples
from TTS.tts.layers.xtts.trainer.gpt_trainer import GPTArgs, GPTTrainer, GPTTrainerConfig, XttsAudioConfig
from TTS.tts.layers.xtts.trainer.gpt_trainer import GPTArgs, GPTTrainer, GPTTrainerConfig
from TTS.utils.manage import ModelManager

# Logging parameters
3 changes: 2 additions & 1 deletion tests/xtts_tests/test_xtts_gpt_train.py
@@ -6,9 +6,10 @@

from tests import get_tests_output_path
from TTS.config.shared_configs import BaseDatasetConfig
from TTS.tts.configs.xtts_config import XttsAudioConfig
from TTS.tts.datasets import load_tts_samples
from TTS.tts.layers.xtts.dvae import DiscreteVAE
from TTS.tts.layers.xtts.trainer.gpt_trainer import GPTArgs, GPTTrainer, GPTTrainerConfig, XttsAudioConfig
from TTS.tts.layers.xtts.trainer.gpt_trainer import GPTArgs, GPTTrainer, GPTTrainerConfig

config_dataset = BaseDatasetConfig(
formatter="ljspeech",
3 changes: 2 additions & 1 deletion tests/xtts_tests/test_xtts_v2-0_gpt_train.py
@@ -6,9 +6,10 @@

from tests import get_tests_output_path
from TTS.config.shared_configs import BaseDatasetConfig
from TTS.tts.configs.xtts_config import XttsAudioConfig
from TTS.tts.datasets import load_tts_samples
from TTS.tts.layers.xtts.dvae import DiscreteVAE
from TTS.tts.layers.xtts.trainer.gpt_trainer import GPTArgs, GPTTrainer, GPTTrainerConfig, XttsAudioConfig
from TTS.tts.layers.xtts.trainer.gpt_trainer import GPTArgs, GPTTrainer, GPTTrainerConfig

config_dataset = BaseDatasetConfig(
formatter="ljspeech",
