Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update to coqui-tts-trainer 0.1.4 #51

Merged
merged 8 commits into from
Jul 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,11 @@ jobs:
sed -i 's/https:\/\/coqui.gateway.scarf.sh\//https:\/\/github.com\/coqui-ai\/TTS\/releases\/download\//g' TTS/.models.json
- name: Install TTS
run: |
python3 -m uv pip install --system "coqui-tts[dev,server,languages] @ ."
python3 setup.py egg_info
resolution=highest
if [ "${{ matrix.python-version }}" == "3.9" ]; then
resolution=lowest-direct
fi
python3 -m uv pip install --resolution=$resolution --system "coqui-tts[dev,server,languages] @ ."
- name: Unit tests
run: make ${{ matrix.subset }}
- name: Upload coverage data
Expand Down
2 changes: 1 addition & 1 deletion TTS/bin/compute_attention_masks.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,14 @@
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
from trainer.io import load_checkpoint

from TTS.config import load_config
from TTS.tts.datasets.TTSDataset import TTSDataset
from TTS.tts.models import setup_model
from TTS.tts.utils.text.characters import make_symbols, phonemes, symbols
from TTS.utils.audio import AudioProcessor
from TTS.utils.generic_utils import ConsoleFormatter, setup_logger
from TTS.utils.io import load_checkpoint

if __name__ == "__main__":
setup_logger("TTS", level=logging.INFO, screen=True, formatter=ConsoleFormatter())
Expand Down
2 changes: 1 addition & 1 deletion TTS/encoder/models/base_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@
import torchaudio
from coqpit import Coqpit
from torch import nn
from trainer.io import load_fsspec

from TTS.encoder.losses import AngleProtoLoss, GE2ELoss, SoftmaxAngleProtoLoss
from TTS.utils.generic_utils import set_init_dict
from TTS.utils.io import load_fsspec

logger = logging.getLogger(__name__)

Expand Down
5 changes: 2 additions & 3 deletions TTS/encoder/utils/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,13 @@

from coqpit import Coqpit
from trainer import TrainerArgs, get_last_checkpoint
from trainer.generic_utils import get_experiment_folder_path
from trainer.generic_utils import get_experiment_folder_path, get_git_branch
from trainer.io import copy_model_files
from trainer.logging import logger_factory
from trainer.logging.console_logger import ConsoleLogger

from TTS.config import load_config, register_config
from TTS.tts.utils.text.characters import parse_symbols
from TTS.utils.generic_utils import get_git_branch


@dataclass
Expand All @@ -30,7 +29,7 @@ def process_args(args, config=None):
args (argparse.Namespace or dict like): Parsed input arguments.
config (Coqpit): Model config. If none, it is generated from `args`. Defaults to None.
Returns:
c (TTS.utils.io.AttrDict): Config paramaters.
c (Coqpit): Config paramaters.
out_path (str): Path to save models and logging.
audio_path (str): Path to save generated test audios.
c_logger (TTS.utils.console_logger.ConsoleLogger): Class that does
Expand Down
3 changes: 2 additions & 1 deletion TTS/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ def load_checkpoint(
checkpoint_path (str | os.PathLike): Path to the model checkpoint file.
eval (bool, optional): If true, init model for inference else for training. Defaults to False.
strict (bool, optional): Match all checkpoint keys to model's keys. Defaults to True.
cache (bool, optional): If True, cache the file locally for subsequent calls. It is cached under `get_user_data_dir()/tts_cache`. Defaults to False.
cache (bool, optional): If True, cache the file locally for subsequent calls.
It is cached under `trainer.io.get_user_data_dir()/tts_cache`. Defaults to False.
"""
...
3 changes: 2 additions & 1 deletion TTS/tts/configs/bark_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,12 @@
from dataclasses import dataclass, field
from typing import Dict

from trainer.io import get_user_data_dir

from TTS.tts.configs.shared_configs import BaseTTSConfig
from TTS.tts.layers.bark.model import GPTConfig
from TTS.tts.layers.bark.model_fine import FineGPTConfig
from TTS.tts.models.bark import BarkAudioConfig
from TTS.utils.generic_utils import get_user_data_dir


@dataclass
Expand Down
2 changes: 1 addition & 1 deletion TTS/tts/layers/xtts/hifigan_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@
from torch.nn import functional as F
from torch.nn.utils.parametrizations import weight_norm
from torch.nn.utils.parametrize import remove_parametrizations
from trainer.io import load_fsspec

from TTS.utils.io import load_fsspec
from TTS.vocoder.models.hifigan_generator import get_padding

logger = logging.getLogger(__name__)
Expand Down
2 changes: 1 addition & 1 deletion TTS/tts/layers/xtts/trainer/gpt_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import torchaudio
from coqpit import Coqpit
from torch.utils.data import DataLoader
from trainer.io import load_fsspec
from trainer.torch import DistributedSampler
from trainer.trainer_utils import get_optimizer, get_scheduler

Expand All @@ -18,7 +19,6 @@
from TTS.tts.layers.xtts.trainer.dataset import XTTSDataset
from TTS.tts.models.base_tts import BaseTTS
from TTS.tts.models.xtts import Xtts, XttsArgs, XttsAudioConfig
from TTS.utils.io import load_fsspec

logger = logging.getLogger(__name__)

Expand Down
2 changes: 1 addition & 1 deletion TTS/tts/models/align_tts.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import torch
from coqpit import Coqpit
from torch import nn
from trainer.io import load_fsspec

from TTS.tts.layers.align_tts.mdn import MDNBlock
from TTS.tts.layers.feed_forward.decoder import Decoder
Expand All @@ -15,7 +16,6 @@
from TTS.tts.utils.speakers import SpeakerManager
from TTS.tts.utils.text.tokenizer import TTSTokenizer
from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
from TTS.utils.io import load_fsspec


@dataclass
Expand Down
5 changes: 3 additions & 2 deletions TTS/tts/models/base_tacotron.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import torch
from coqpit import Coqpit
from torch import nn
from trainer.io import load_fsspec

from TTS.tts.layers.losses import TacotronLoss
from TTS.tts.models.base_tts import BaseTTS
Expand All @@ -15,7 +16,6 @@
from TTS.tts.utils.text.tokenizer import TTSTokenizer
from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
from TTS.utils.generic_utils import format_aux_input
from TTS.utils.io import load_fsspec
from TTS.utils.training import gradual_training_scheduler

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -103,7 +103,8 @@ def load_checkpoint(
config (Coqpi): model configuration.
checkpoint_path (str): path to checkpoint file.
eval (bool, optional): whether to load model for evaluation.
cache (bool, optional): If True, cache the file locally for subsequent calls. It is cached under `get_user_data_dir()/tts_cache`. Defaults to False.
cache (bool, optional): If True, cache the file locally for subsequent calls.
It is cached under `trainer.io.get_user_data_dir()/tts_cache`. Defaults to False.
"""
state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"), cache=cache)
self.load_state_dict(state["model"])
Expand Down
2 changes: 1 addition & 1 deletion TTS/tts/models/delightful_tts.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torch.utils.data.sampler import WeightedRandomSampler
from trainer.io import load_fsspec
from trainer.torch import DistributedSampler, DistributedSamplerWrapper
from trainer.trainer_utils import get_optimizer, get_scheduler

Expand All @@ -32,7 +33,6 @@
from TTS.utils.audio.numpy_transforms import db_to_amp as db_to_amp_numpy
from TTS.utils.audio.numpy_transforms import mel_to_wav as mel_to_wav_numpy
from TTS.utils.audio.processor import AudioProcessor
from TTS.utils.io import load_fsspec
from TTS.vocoder.layers.losses import MultiScaleSTFTLoss
from TTS.vocoder.models.hifigan_generator import HifiganGenerator
from TTS.vocoder.utils.generic_utils import plot_results
Expand Down
2 changes: 1 addition & 1 deletion TTS/tts/models/forward_tts.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from coqpit import Coqpit
from torch import nn
from torch.cuda.amp.autocast_mode import autocast
from trainer.io import load_fsspec

from TTS.tts.layers.feed_forward.decoder import Decoder
from TTS.tts.layers.feed_forward.encoder import Encoder
Expand All @@ -17,7 +18,6 @@
from TTS.tts.utils.speakers import SpeakerManager
from TTS.tts.utils.text.tokenizer import TTSTokenizer
from TTS.tts.utils.visual import plot_alignment, plot_avg_energy, plot_avg_pitch, plot_spectrogram
from TTS.utils.io import load_fsspec

logger = logging.getLogger(__name__)

Expand Down
2 changes: 1 addition & 1 deletion TTS/tts/models/glow_tts.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from torch import nn
from torch.cuda.amp.autocast_mode import autocast
from torch.nn import functional as F
from trainer.io import load_fsspec

from TTS.tts.configs.glow_tts_config import GlowTTSConfig
from TTS.tts.layers.glow_tts.decoder import Decoder
Expand All @@ -17,7 +18,6 @@
from TTS.tts.utils.synthesis import synthesis
from TTS.tts.utils.text.tokenizer import TTSTokenizer
from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
from TTS.utils.io import load_fsspec

logger = logging.getLogger(__name__)

Expand Down
2 changes: 1 addition & 1 deletion TTS/tts/models/neuralhmm_tts.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import torch
from coqpit import Coqpit
from torch import nn
from trainer.io import load_fsspec
from trainer.logging.tensorboard_logger import TensorboardLogger

from TTS.tts.layers.overflow.common_layers import Encoder, OverflowUtils
Expand All @@ -18,7 +19,6 @@
from TTS.tts.utils.text.tokenizer import TTSTokenizer
from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
from TTS.utils.generic_utils import format_aux_input
from TTS.utils.io import load_fsspec

logger = logging.getLogger(__name__)

Expand Down
2 changes: 1 addition & 1 deletion TTS/tts/models/overflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import torch
from coqpit import Coqpit
from torch import nn
from trainer.io import load_fsspec
from trainer.logging.tensorboard_logger import TensorboardLogger

from TTS.tts.layers.overflow.common_layers import Encoder, OverflowUtils
Expand All @@ -19,7 +20,6 @@
from TTS.tts.utils.text.tokenizer import TTSTokenizer
from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
from TTS.utils.generic_utils import format_aux_input
from TTS.utils.io import load_fsspec

logger = logging.getLogger(__name__)

Expand Down
2 changes: 1 addition & 1 deletion TTS/tts/models/vits.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torch.utils.data.sampler import WeightedRandomSampler
from trainer.io import load_fsspec
from trainer.torch import DistributedSampler, DistributedSamplerWrapper
from trainer.trainer_utils import get_optimizer, get_scheduler

Expand All @@ -34,7 +35,6 @@
from TTS.tts.utils.text.characters import BaseCharacters, BaseVocabulary, _characters, _pad, _phonemes, _punctuations
from TTS.tts.utils.text.tokenizer import TTSTokenizer
from TTS.tts.utils.visual import plot_alignment
from TTS.utils.io import load_fsspec
from TTS.utils.samplers import BucketBatchSampler
from TTS.vocoder.models.hifigan_generator import HifiganGenerator
from TTS.vocoder.utils.generic_utils import plot_results
Expand Down
2 changes: 1 addition & 1 deletion TTS/tts/models/xtts.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,14 @@
import torch.nn.functional as F
import torchaudio
from coqpit import Coqpit
from trainer.io import load_fsspec

from TTS.tts.layers.xtts.gpt import GPT
from TTS.tts.layers.xtts.hifigan_decoder import HifiDecoder
from TTS.tts.layers.xtts.stream_generator import init_stream_support
from TTS.tts.layers.xtts.tokenizer import VoiceBpeTokenizer, split_sentence
from TTS.tts.layers.xtts.xtts_manager import LanguageManager, SpeakerManager
from TTS.tts.models.base_tts import BaseTTS
from TTS.utils.io import load_fsspec

logger = logging.getLogger(__name__)

Expand Down
38 changes: 0 additions & 38 deletions TTS/utils/generic_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,29 +2,13 @@
import datetime
import importlib
import logging
import os
import re
import subprocess
import sys
from pathlib import Path
from typing import Dict, Optional

logger = logging.getLogger(__name__)


# TODO: This method is duplicated in Trainer but out of date there
def get_git_branch():
try:
out = subprocess.check_output(["git", "branch"]).decode("utf8")
current = next(line for line in out.split("\n") if line.startswith("*"))
current.replace("* ", "")
except subprocess.CalledProcessError:
current = "inside_docker"
except (FileNotFoundError, StopIteration) as e:
current = "unknown"
return current


def to_camel(text):
text = text.capitalize()
text = re.sub(r"(?!^)_([a-zA-Z])", lambda m: m.group(1).upper(), text)
Expand Down Expand Up @@ -67,28 +51,6 @@ def get_import_path(obj: object) -> str:
return ".".join([type(obj).__module__, type(obj).__name__])


def get_user_data_dir(appname):
TTS_HOME = os.environ.get("TTS_HOME")
XDG_DATA_HOME = os.environ.get("XDG_DATA_HOME")
if TTS_HOME is not None:
ans = Path(TTS_HOME).expanduser().resolve(strict=False)
elif XDG_DATA_HOME is not None:
ans = Path(XDG_DATA_HOME).expanduser().resolve(strict=False)
elif sys.platform == "win32":
import winreg # pylint: disable=import-outside-toplevel

key = winreg.OpenKey(
winreg.HKEY_CURRENT_USER, r"Software\Microsoft\Windows\CurrentVersion\Explorer\Shell Folders"
)
dir_, _ = winreg.QueryValueEx(key, "Local AppData")
ans = Path(dir_).resolve(strict=False)
elif sys.platform == "darwin":
ans = Path("~/Library/Application Support/").expanduser()
else:
ans = Path.home().joinpath(".local/share")
return ans.joinpath(appname)


def set_init_dict(model_dict, checkpoint_state, c):
# Partial initialization: if there is a mismatch with new and old layer, it is skipped.
for k, v in checkpoint_state.items():
Expand Down
70 changes: 0 additions & 70 deletions TTS/utils/io.py

This file was deleted.

Loading
Loading