diff --git a/.github/actions/setup-uv/action.yml b/.github/actions/setup-uv/action.yml
index c7dd4f5f99..88a73e8481 100644
--- a/.github/actions/setup-uv/action.yml
+++ b/.github/actions/setup-uv/action.yml
@@ -4,8 +4,9 @@ runs:
using: 'composite'
steps:
- name: Install uv
- uses: astral-sh/setup-uv@v4
+ uses: astral-sh/setup-uv@v5
with:
- version: "0.5.4"
+ version: "0.5.17"
enable-cache: true
cache-dependency-glob: "**/pyproject.toml"
+ python-version: ${{ matrix.python-version }}
diff --git a/.github/workflows/docker.yaml b/.github/workflows/docker.yaml
index 249816a320..874da889cb 100644
--- a/.github/workflows/docker.yaml
+++ b/.github/workflows/docker.yaml
@@ -1,12 +1,29 @@
name: "Docker build and push"
on:
pull_request:
+ paths-ignore:
+ - '.gitignore'
+ - 'CITATION.cff'
+ - 'CODE_OF_CONDUCT.md'
+ - 'CONTRIBUTING.md'
+ - 'LICENSE.txt'
+ - 'README.md'
+ - 'images/**'
push:
branches:
- main
- dev
tags:
- v*
+ paths-ignore:
+ - '.gitignore'
+ - 'CITATION.cff'
+ - 'CODE_OF_CONDUCT.md'
+ - 'CONTRIBUTING.md'
+ - 'LICENSE.txt'
+ - 'README.md'
+ - 'images/**'
+
jobs:
docker-build:
name: "Build and push Docker image"
diff --git a/.github/workflows/pypi-release.yml b/.github/workflows/pypi-release.yml
index ef74c60da6..9dad7f4120 100644
--- a/.github/workflows/pypi-release.yml
+++ b/.github/workflows/pypi-release.yml
@@ -4,8 +4,8 @@ on:
types: [published]
defaults:
run:
- shell:
- bash
+ shell: bash
+
jobs:
build:
runs-on: ubuntu-latest
diff --git a/.github/workflows/style_check.yml b/.github/workflows/style_check.yml
index d1060f6be2..a13695c5d2 100644
--- a/.github/workflows/style_check.yml
+++ b/.github/workflows/style_check.yml
@@ -4,20 +4,31 @@ on:
push:
branches:
- main
+ paths-ignore:
+ - '.gitignore'
+ - 'CITATION.cff'
+ - 'CODE_OF_CONDUCT.md'
+ - 'CONTRIBUTING.md'
+ - 'LICENSE.txt'
+ - 'README.md'
+ - 'images/**'
pull_request:
types: [opened, synchronize, reopened]
+ paths-ignore:
+ - '.gitignore'
+ - 'CITATION.cff'
+ - 'CODE_OF_CONDUCT.md'
+ - 'CONTRIBUTING.md'
+ - 'LICENSE.txt'
+ - 'README.md'
+ - 'images/**'
+
jobs:
lint:
runs-on: ubuntu-latest
- strategy:
- fail-fast: false
- matrix:
- python-version: [3.9]
steps:
- uses: actions/checkout@v4
- name: Setup uv
uses: ./.github/actions/setup-uv
- - name: Set up Python ${{ matrix.python-version }}
- run: uv python install ${{ matrix.python-version }}
- name: Lint check
run: make lint
diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml
index 7905add3f7..4d178b0bad 100644
--- a/.github/workflows/tests.yml
+++ b/.github/workflows/tests.yml
@@ -4,8 +4,24 @@ on:
push:
branches:
- main
+ paths-ignore:
+ - '.gitignore'
+ - 'CITATION.cff'
+ - 'CODE_OF_CONDUCT.md'
+ - 'CONTRIBUTING.md'
+ - 'LICENSE.txt'
+ - 'README.md'
+ - 'images/**'
pull_request:
types: [opened, synchronize, reopened]
+ paths-ignore:
+ - '.gitignore'
+ - 'CITATION.cff'
+ - 'CODE_OF_CONDUCT.md'
+ - 'CONTRIBUTING.md'
+ - 'LICENSE.txt'
+ - 'README.md'
+ - 'images/**'
workflow_dispatch:
inputs:
trainer_branch:
@@ -16,20 +32,27 @@ on:
description: "Branch of Coqpit to test"
required: false
default: "main"
+ paths-ignore:
+ - '.gitignore'
+ - 'CITATION.cff'
+ - 'CODE_OF_CONDUCT.md'
+ - 'CONTRIBUTING.md'
+ - 'LICENSE.txt'
+ - 'README.md'
+ - 'images/**'
+
jobs:
unit:
runs-on: ubuntu-latest
strategy:
fail-fast: false
matrix:
- python-version: [3.9, "3.10", "3.11", "3.12"]
+ python-version: ["3.10", "3.11", "3.12"]
subset: ["data_tests", "inference_tests", "test_aux", "test_text"]
steps:
- uses: actions/checkout@v4
- name: Setup uv
uses: ./.github/actions/setup-uv
- - name: Set up Python ${{ matrix.python-version }}
- run: uv python install ${{ matrix.python-version }}
- name: Install Espeak
if: contains(fromJSON('["inference_tests", "test_text"]'), matrix.subset)
run: |
@@ -50,7 +73,7 @@ jobs:
- name: Unit tests
run: |
resolution=highest
- if [ "${{ matrix.python-version }}" == "3.9" ]; then
+ if [ "${{ matrix.python-version }}" == "3.10" ]; then
resolution=lowest-direct
fi
uv run --resolution=$resolution --extra server --extra languages make ${{ matrix.subset }}
@@ -60,22 +83,18 @@ jobs:
include-hidden-files: true
name: coverage-data-${{ matrix.subset }}-${{ matrix.python-version }}
path: .coverage.*
- if-no-files-found: ignore
integration:
runs-on: ubuntu-latest
strategy:
fail-fast: false
matrix:
- python-version: ["3.9", "3.12"]
- subset: ["test_tts", "test_tts2", "test_vocoder", "test_xtts"]
+ python-version: ["3.10", "3.12"]
+ shard: [0, 1, 2, 3, 4]
steps:
- uses: actions/checkout@v4
- name: Setup uv
uses: ./.github/actions/setup-uv
- - name: Set up Python ${{ matrix.python-version }}
- run: uv python install ${{ matrix.python-version }}
- name: Install Espeak
- if: contains(fromJSON('["test_tts", "test_tts2", "test_xtts"]'), matrix.subset)
run: |
sudo apt-get update
sudo apt-get install espeak espeak-ng
@@ -91,20 +110,22 @@ jobs:
if [[ -n "${{ github.event.inputs.coqpit_branch }}" ]]; then
uv add git+https://github.com/idiap/coqui-ai-coqpit --branch ${{ github.event.inputs.coqpit_branch }}
fi
- - name: Integration tests
+ - name: Integration tests for shard ${{ matrix.shard }}
run: |
+ uv run pytest tests/integration --collect-only --quiet | grep "::" > integration_tests.txt
+ total_shards=5
+ shard_tests=$(awk "NR % $total_shards == ${{ matrix.shard }}" integration_tests.txt)
resolution=highest
- if [ "${{ matrix.python-version }}" == "3.9" ]; then
+ if [ "${{ matrix.python-version }}" == "3.10" ]; then
resolution=lowest-direct
fi
- uv run --resolution=$resolution --extra server --extra languages make ${{ matrix.subset }}
+ uv run --resolution=$resolution --extra languages coverage run -m pytest -x -v --durations=0 $shard_tests
- name: Upload coverage data
uses: actions/upload-artifact@v4
with:
include-hidden-files: true
- name: coverage-data-${{ matrix.subset }}-${{ matrix.python-version }}
+ name: coverage-data-integration-${{ matrix.shard }}-${{ matrix.python-version }}
path: .coverage.*
- if-no-files-found: ignore
zoo:
runs-on: ubuntu-latest
strategy:
@@ -116,8 +137,6 @@ jobs:
- uses: actions/checkout@v4
- name: Setup uv
uses: ./.github/actions/setup-uv
- - name: Set up Python ${{ matrix.python-version }}
- run: uv python install ${{ matrix.python-version }}
- name: Install Espeak
run: |
sudo apt-get update
@@ -145,7 +164,6 @@ jobs:
include-hidden-files: true
name: coverage-data-zoo-${{ matrix.partition }}
path: .coverage.*
- if-no-files-found: ignore
coverage:
if: always()
needs: [unit, integration, zoo]
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 62420e9958..2f070ad085 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -7,13 +7,9 @@ repos:
- id: check-yaml
- id: end-of-file-fixer
- id: trailing-whitespace
- - repo: "https://github.com/psf/black"
- rev: 24.2.0
- hooks:
- - id: black
- language_version: python3
- repo: https://github.com/astral-sh/ruff-pre-commit
- rev: v0.7.0
+ rev: v0.9.1
hooks:
- id: ruff
args: [--fix, --exit-non-zero-on-fix]
+ - id: ruff-format
diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
index 2b3a973763..5fe9421442 100644
--- a/CONTRIBUTING.md
+++ b/CONTRIBUTING.md
@@ -88,7 +88,7 @@ curl -LsSf https://astral.sh/uv/install.sh | sh
uv run make test_all # run all the tests, report all the errors
```
-9. Format your code. We use ```black``` for code formatting.
+9. Format your code. We use ```ruff``` for code formatting.
```bash
make style
diff --git a/Makefile b/Makefile
index 35345b8c1f..da714e7b34 100644
--- a/Makefile
+++ b/Makefile
@@ -15,12 +15,6 @@ test_vocoder: ## run vocoder tests.
test_tts: ## run tts tests.
coverage run -m pytest -x -v --durations=0 tests/tts_tests
-test_tts2: ## run tts tests.
- coverage run -m pytest -x -v --durations=0 tests/tts_tests2
-
-test_xtts:
- coverage run -m pytest -x -v --durations=0 tests/xtts_tests
-
test_aux: ## run aux tests.
coverage run -m pytest -x -v --durations=0 tests/aux_tests
@@ -43,11 +37,11 @@ test_failed: ## only run tests failed the last time.
coverage run -m pytest -x -v --last-failed tests
style: ## update code style.
- uv run --only-dev black ${target_dirs}
+ uv run --only-dev ruff format ${target_dirs}
lint: ## run linters.
uv run --only-dev ruff check ${target_dirs}
- uv run --only-dev black ${target_dirs} --check
+ uv run --only-dev ruff format ${target_dirs} --check
system-deps: ## install linux system deps
sudo apt-get install -y libsndfile1-dev
diff --git a/README.md b/README.md
index c0843b731d..db8868b26d 100644
--- a/README.md
+++ b/README.md
@@ -116,7 +116,7 @@ You can also help us implement more models.
## Installation
-🐸TTS is tested on Ubuntu 24.04 with **python >= 3.9, < 3.13**, but should also
+🐸TTS is tested on Ubuntu 24.04 with **python >= 3.10, < 3.13**, but should also
work on Mac and Windows.
If you are only interested in [synthesizing speech](https://coqui-tts.readthedocs.io/en/latest/inference.html) with the pretrained 🐸TTS models, installing from PyPI is the easiest option.
diff --git a/TTS/.models.json b/TTS/.models.json
index 05c88bef43..624a6a0489 100644
--- a/TTS/.models.json
+++ b/TTS/.models.json
@@ -723,6 +723,17 @@
"description": "persian-tts-female-glow_tts model for text to speech purposes. Single-speaker female voice Trained on persian-tts-dataset-famale. \nThis model has no compatible vocoder thus the output quality is not very good. \nDataset: https://www.kaggle.com/datasets/magnoliasis/persian-tts-dataset-famale.",
"author": "@karim23657",
"license": "CC-BY-4.0"
+ },
+ "vits-female": {
+ "hf_url": [
+ "https://huggingface.co/Kamtera/persian-tts-female-vits/resolve/main/best_model_30824.pth",
+ "https://huggingface.co/Kamtera/persian-tts-female-vits/resolve/main/config.json"
+ ],
+ "default_vocoder": null,
+ "commit": null,
+ "description": "persian-tts-female-vits model for text to speech purposes. Single-speaker female voice trained on persian-tts-dataset-female.\nDataset: https://www.kaggle.com/datasets/magnoliasis/persian-tts-dataset-famale.",
+ "author": "@karim23657",
+ "license": "openrail"
}
}
},
diff --git a/TTS/__init__.py b/TTS/__init__.py
index 8e93c9b5db..d270e09e22 100644
--- a/TTS/__init__.py
+++ b/TTS/__init__.py
@@ -4,6 +4,15 @@
__version__ = importlib.metadata.version("coqui-tts")
+if "coqpit" in importlib.metadata.packages_distributions().get("coqpit", []):
+ msg = (
+ "coqui-tts switched to a forked version of Coqpit, but you still have the original "
+ "package installed. Run the following to avoid conflicts:\n"
+ " pip uninstall coqpit\n"
+ " pip install coqpit-config"
+ )
+ raise ImportError(msg)
+
if is_pytorch_at_least_2_4():
import _codecs
diff --git a/TTS/api.py b/TTS/api.py
index 6db929411c..bf91df4c6d 100644
--- a/TTS/api.py
+++ b/TTS/api.py
@@ -4,7 +4,6 @@
import tempfile
import warnings
from pathlib import Path
-from typing import Optional, Union
from torch import nn
@@ -22,15 +21,15 @@ def __init__(
self,
model_name: str = "",
*,
- model_path: Optional[str] = None,
- config_path: Optional[str] = None,
- vocoder_name: Optional[str] = None,
- vocoder_path: Optional[str] = None,
- vocoder_config_path: Optional[str] = None,
- encoder_path: Optional[str] = None,
- encoder_config_path: Optional[str] = None,
- speakers_file_path: Optional[str] = None,
- language_ids_file_path: Optional[str] = None,
+ model_path: str | None = None,
+ config_path: str | None = None,
+ vocoder_name: str | None = None,
+ vocoder_path: str | None = None,
+ vocoder_config_path: str | None = None,
+ encoder_path: str | None = None,
+ encoder_config_path: str | None = None,
+ speakers_file_path: str | None = None,
+ language_ids_file_path: str | None = None,
progress_bar: bool = True,
gpu: bool = False,
) -> None:
@@ -77,8 +76,8 @@ def __init__(
super().__init__()
self.manager = ModelManager(models_file=self.get_models_file_path(), progress_bar=progress_bar)
self.config = load_config(config_path) if config_path else None
- self.synthesizer: Optional[Synthesizer] = None
- self.voice_converter: Optional[Synthesizer] = None
+ self.synthesizer: Synthesizer | None = None
+ self.voice_converter: Synthesizer | None = None
self.model_name = ""
self.vocoder_path = vocoder_path
@@ -156,10 +155,18 @@ def list_models() -> list[str]:
return ModelManager(models_file=TTS.get_models_file_path(), progress_bar=False).list_models()
def download_model_by_name(
- self, model_name: str, vocoder_name: Optional[str] = None
- ) -> tuple[Optional[Path], Optional[Path], Optional[Path], Optional[Path], Optional[Path]]:
+ self, model_name: str, vocoder_name: str | None = None
+ ) -> tuple[Path | None, Path | None, Path | None, Path | None, Path | None]:
model_path, config_path, model_item = self.manager.download_model(model_name)
- if "fairseq" in model_name or (model_item is not None and isinstance(model_item["model_url"], list)):
+ if (
+ "fairseq" in model_name
+ or "openvoice" in model_name
+ or (
+ model_item is not None
+ and isinstance(model_item["model_url"], list)
+ and len(model_item["model_url"]) > 2
+ )
+ ):
# return model directory if there are multiple files
# we assume that the model knows how to load itself
return None, None, None, None, model_path
@@ -176,7 +183,7 @@ def download_model_by_name(
vocoder_path, vocoder_config_path, _ = self.manager.download_model(vocoder_name)
return model_path, config_path, vocoder_path, vocoder_config_path, None
- def load_model_by_name(self, model_name: str, vocoder_name: Optional[str] = None, *, gpu: bool = False) -> None:
+ def load_model_by_name(self, model_name: str, vocoder_name: str | None = None, *, gpu: bool = False) -> None:
"""Load one of the 🐸TTS models by name.
Args:
@@ -185,7 +192,7 @@ def load_model_by_name(self, model_name: str, vocoder_name: Optional[str] = None
"""
self.load_tts_model_by_name(model_name, vocoder_name, gpu=gpu)
- def load_vc_model_by_name(self, model_name: str, vocoder_name: Optional[str] = None, *, gpu: bool = False) -> None:
+ def load_vc_model_by_name(self, model_name: str, vocoder_name: str | None = None, *, gpu: bool = False) -> None:
"""Load one of the voice conversion models by name.
Args:
@@ -205,7 +212,7 @@ def load_vc_model_by_name(self, model_name: str, vocoder_name: Optional[str] = N
use_cuda=gpu,
)
- def load_tts_model_by_name(self, model_name: str, vocoder_name: Optional[str] = None, *, gpu: bool = False) -> None:
+ def load_tts_model_by_name(self, model_name: str, vocoder_name: str | None = None, *, gpu: bool = False) -> None:
"""Load one of 🐸TTS models by name.
Args:
@@ -261,11 +268,10 @@ def load_tts_model_by_path(self, model_path: str, config_path: str, *, gpu: bool
def _check_arguments(
self,
- speaker: Optional[str] = None,
- language: Optional[str] = None,
- speaker_wav: Optional[str] = None,
- emotion: Optional[str] = None,
- speed: Optional[float] = None,
+ speaker: str | None = None,
+ language: str | None = None,
+ speaker_wav: str | None = None,
+ emotion: str | None = None,
**kwargs,
) -> None:
"""Check if the arguments are valid for the model."""
@@ -278,17 +284,16 @@ def _check_arguments(
raise ValueError("Model is not multi-speaker but `speaker` is provided.")
if not self.is_multi_lingual and language is not None:
raise ValueError("Model is not multi-lingual but `language` is provided.")
- if emotion is not None and speed is not None:
- raise ValueError("Emotion and speed can only be used with Coqui Studio models. Which is discontinued.")
+ if emotion is not None:
+            raise ValueError("Emotion can only be used with the discontinued Coqui Studio models.")
def tts(
self,
text: str,
- speaker: Optional[str] = None,
- language: Optional[str] = None,
- speaker_wav: Optional[str] = None,
- emotion: Optional[str] = None,
- speed: Optional[float] = None,
+ speaker: str | None = None,
+ language: str | None = None,
+ speaker_wav: str | None = None,
+ emotion: str | None = None,
split_sentences: bool = True,
**kwargs,
):
@@ -307,9 +312,6 @@ def tts(
Defaults to None.
emotion (str, optional):
Emotion to use for 🐸Coqui Studio models. If None, Studio models use "Neutral". Defaults to None.
- speed (float, optional):
- Speed factor to use for 🐸Coqui Studio models, between 0 and 2.0. If None, Studio models use 1.0.
- Defaults to None.
split_sentences (bool, optional):
Split text into sentences, synthesize them separately and concatenate the file audio.
Setting it False uses more VRAM and possibly hit model specific text length or VRAM limits. Only
@@ -317,9 +319,7 @@ def tts(
kwargs (dict, optional):
Additional arguments for the model.
"""
- self._check_arguments(
- speaker=speaker, language=language, speaker_wav=speaker_wav, emotion=emotion, speed=speed, **kwargs
- )
+ self._check_arguments(speaker=speaker, language=language, speaker_wav=speaker_wav, emotion=emotion, **kwargs)
wav = self.synthesizer.tts(
text=text,
speaker_name=speaker,
@@ -333,11 +333,10 @@ def tts(
def tts_to_file(
self,
text: str,
- speaker: Optional[str] = None,
- language: Optional[str] = None,
- speaker_wav: Optional[str] = None,
- emotion: Optional[str] = None,
- speed: float = 1.0,
+ speaker: str | None = None,
+ language: str | None = None,
+ speaker_wav: str | None = None,
+ emotion: str | None = None,
pipe_out=None,
file_path: str = "output.wav",
split_sentences: bool = True,
@@ -359,8 +358,6 @@ def tts_to_file(
Defaults to None.
emotion (str, optional):
Emotion to use for 🐸Coqui Studio models. Defaults to "Neutral".
- speed (float, optional):
- Speed factor to use for 🐸Coqui Studio models, between 0.0 and 2.0. Defaults to None.
pipe_out (BytesIO, optional):
Flag to stdout the generated TTS wav file for shell pipe.
file_path (str, optional):
@@ -388,7 +385,7 @@ def tts_to_file(
def voice_conversion(
self,
source_wav: str,
- target_wav: Union[str, list[str]],
+ target_wav: str | list[str],
):
"""Voice conversion with FreeVC. Convert source wav to target speaker.
@@ -406,7 +403,7 @@ def voice_conversion(
def voice_conversion_to_file(
self,
source_wav: str,
- target_wav: Union[str, list[str]],
+ target_wav: str | list[str],
file_path: str = "output.wav",
pipe_out=None,
) -> str:
@@ -430,9 +427,9 @@ def tts_with_vc(
self,
text: str,
*,
- language: Optional[str] = None,
- speaker_wav: Union[str, list[str]],
- speaker: Optional[str] = None,
+ language: str | None = None,
+ speaker_wav: str | list[str],
+ speaker: str | None = None,
split_sentences: bool = True,
):
"""Convert text to speech with voice conversion.
@@ -473,10 +470,10 @@ def tts_with_vc_to_file(
self,
text: str,
*,
- language: Optional[str] = None,
- speaker_wav: Union[str, list[str]],
+ language: str | None = None,
+ speaker_wav: str | list[str],
file_path: str = "output.wav",
- speaker: Optional[str] = None,
+ speaker: str | None = None,
split_sentences: bool = True,
pipe_out=None,
) -> str:
diff --git a/TTS/bin/compute_embeddings.py b/TTS/bin/compute_embeddings.py
index f103350912..d450e26fba 100644
--- a/TTS/bin/compute_embeddings.py
+++ b/TTS/bin/compute_embeddings.py
@@ -3,7 +3,6 @@
import os
import sys
from argparse import RawTextHelpFormatter
-from typing import Optional
import torch
from tqdm import tqdm
@@ -16,7 +15,7 @@
from TTS.utils.generic_utils import ConsoleFormatter, setup_logger
-def parse_args(arg_list: Optional[list[str]]) -> argparse.Namespace:
+def parse_args(arg_list: list[str] | None) -> argparse.Namespace:
parser = argparse.ArgumentParser(
description="""Compute embedding vectors for each audio file in a dataset and store them keyed by `{dataset_name}#{file_path}` in a .pth file\n\n"""
"""
@@ -185,7 +184,7 @@ def compute_embeddings(
print("Speaker embeddings saved at:", mapping_file_path)
-def main(arg_list: Optional[list[str]] = None):
+def main(arg_list: list[str] | None = None):
setup_logger("TTS", level=logging.INFO, stream=sys.stdout, formatter=ConsoleFormatter())
args = parse_args(arg_list)
diff --git a/TTS/bin/compute_statistics.py b/TTS/bin/compute_statistics.py
index b7c52ac6c5..1da7a092fb 100755
--- a/TTS/bin/compute_statistics.py
+++ b/TTS/bin/compute_statistics.py
@@ -1,12 +1,10 @@
#!/usr/bin/env python3
-# -*- coding: utf-8 -*-
import argparse
import glob
import logging
import os
import sys
-from typing import Optional
import numpy as np
from tqdm import tqdm
@@ -18,7 +16,7 @@
from TTS.utils.generic_utils import ConsoleFormatter, setup_logger
-def parse_args(arg_list: Optional[list[str]]) -> tuple[argparse.Namespace, list[str]]:
+def parse_args(arg_list: list[str] | None) -> tuple[argparse.Namespace, list[str]]:
parser = argparse.ArgumentParser(description="Compute mean and variance of spectrogtram features.")
parser.add_argument("config_path", type=str, help="TTS config file path to define audio processin parameters.")
parser.add_argument("out_path", type=str, help="save path (directory and filename).")
@@ -31,7 +29,7 @@ def parse_args(arg_list: Optional[list[str]]) -> tuple[argparse.Namespace, list[
return parser.parse_known_args(arg_list)
-def main(arg_list: Optional[list[str]] = None):
+def main(arg_list: list[str] | None = None):
"""Run preprocessing process."""
setup_logger("TTS", level=logging.INFO, stream=sys.stderr, formatter=ConsoleFormatter())
args, overrides = parse_args(arg_list)
diff --git a/TTS/bin/extract_tts_spectrograms.py b/TTS/bin/extract_tts_spectrograms.py
index 77072f9efa..be9387f015 100755
--- a/TTS/bin/extract_tts_spectrograms.py
+++ b/TTS/bin/extract_tts_spectrograms.py
@@ -5,7 +5,6 @@
import logging
import sys
from pathlib import Path
-from typing import Optional
import numpy as np
import torch
@@ -27,7 +26,7 @@
use_cuda = torch.cuda.is_available()
-def parse_args(arg_list: Optional[list[str]]) -> argparse.Namespace:
+def parse_args(arg_list: list[str] | None) -> argparse.Namespace:
parser = argparse.ArgumentParser()
parser.add_argument("--config_path", type=str, help="Path to config file for training.", required=True)
parser.add_argument("--checkpoint_path", type=str, help="Model file to be restored.", required=True)
@@ -244,7 +243,7 @@ def extract_spectrograms(
f.write(f"{data[0] / data[1]}.npy\n")
-def main(arg_list: Optional[list[str]] = None) -> None:
+def main(arg_list: list[str] | None = None) -> None:
setup_logger("TTS", level=logging.INFO, stream=sys.stdout, formatter=ConsoleFormatter())
args = parse_args(arg_list)
config = load_config(args.config_path)
diff --git a/TTS/bin/find_unique_phonemes.py b/TTS/bin/find_unique_phonemes.py
index 0c453db85b..40afa1456c 100644
--- a/TTS/bin/find_unique_phonemes.py
+++ b/TTS/bin/find_unique_phonemes.py
@@ -5,7 +5,6 @@
import multiprocessing
import sys
from argparse import RawTextHelpFormatter
-from typing import Optional
from tqdm.contrib.concurrent import process_map
@@ -21,7 +20,7 @@ def compute_phonemes(item: dict) -> set[str]:
return set(ph)
-def parse_args(arg_list: Optional[list[str]]) -> argparse.Namespace:
+def parse_args(arg_list: list[str] | None) -> argparse.Namespace:
parser = argparse.ArgumentParser(
description="""Find all the unique characters or phonemes in a dataset.\n\n"""
"""
@@ -35,7 +34,7 @@ def parse_args(arg_list: Optional[list[str]]) -> argparse.Namespace:
return parser.parse_args(arg_list)
-def main(arg_list: Optional[list[str]] = None) -> None:
+def main(arg_list: list[str] | None = None) -> None:
setup_logger("TTS", level=logging.INFO, stream=sys.stdout, formatter=ConsoleFormatter())
global phonemizer
args = parse_args(arg_list)
diff --git a/TTS/bin/synthesize.py b/TTS/bin/synthesize.py
index f963485c5d..00d7530427 100755
--- a/TTS/bin/synthesize.py
+++ b/TTS/bin/synthesize.py
@@ -7,7 +7,6 @@
import logging
import sys
from argparse import RawTextHelpFormatter
-from typing import Optional
# pylint: disable=redefined-outer-name, unused-argument
from TTS.utils.generic_utils import ConsoleFormatter, setup_logger
@@ -135,7 +134,7 @@
"""
-def parse_args(arg_list: Optional[list[str]]) -> argparse.Namespace:
+def parse_args(arg_list: list[str] | None) -> argparse.Namespace:
"""Parse arguments."""
parser = argparse.ArgumentParser(
description=description.replace(" ```\n", ""),
@@ -311,7 +310,7 @@ def parse_args(arg_list: Optional[list[str]]) -> argparse.Namespace:
return args
-def main(arg_list: Optional[list[str]] = None) -> None:
+def main(arg_list: list[str] | None = None) -> None:
"""Entry point for `tts` command line interface."""
args = parse_args(arg_list)
stream = sys.stderr if args.pipe_out else sys.stdout
diff --git a/TTS/bin/train_encoder.py b/TTS/bin/train_encoder.py
index a37ab8efc9..06189a44c3 100644
--- a/TTS/bin/train_encoder.py
+++ b/TTS/bin/train_encoder.py
@@ -1,25 +1,30 @@
#!/usr/bin/env python3
-# -*- coding: utf-8 -*-
+
+# TODO: use Trainer
import logging
import os
import sys
import time
-import traceback
import warnings
+from dataclasses import dataclass, field
import torch
from torch.utils.data import DataLoader
-from trainer.generic_utils import count_parameters, remove_experiment_folder
-from trainer.io import copy_model_files, save_best_model, save_checkpoint
+from trainer import TrainerArgs, TrainerConfig
+from trainer.generic_utils import count_parameters, get_experiment_folder_path, get_git_branch
+from trainer.io import copy_model_files, get_last_checkpoint, save_best_model, save_checkpoint
+from trainer.logging import BaseDashboardLogger, ConsoleLogger, logger_factory
from trainer.torch import NoamLR
from trainer.trainer_utils import get_optimizer
+from TTS.config import load_config, register_config
+from TTS.encoder.configs.base_encoder_config import BaseEncoderConfig
from TTS.encoder.dataset import EncoderDataset
from TTS.encoder.utils.generic_utils import setup_encoder_model
-from TTS.encoder.utils.training import init_training
from TTS.encoder.utils.visual import plot_embeddings
from TTS.tts.datasets import load_tts_samples
+from TTS.tts.utils.text.characters import parse_symbols
from TTS.utils.audio import AudioProcessor
from TTS.utils.generic_utils import ConsoleFormatter, setup_logger
from TTS.utils.samplers import PerfectBatchSampler
@@ -34,7 +39,77 @@
print(" > Number of GPUs: ", num_gpus)
-def setup_loader(ap: AudioProcessor, is_val: bool = False):
+@dataclass
+class TrainArgs(TrainerArgs):
+ config_path: str | None = field(default=None, metadata={"help": "Path to the config file."})
+
+
+def process_args(
+ args, config: BaseEncoderConfig | None = None
+) -> tuple[BaseEncoderConfig, str, str, ConsoleLogger, BaseDashboardLogger | None]:
+    """Process parsed command line arguments and initialize the config if not provided.
+ Args:
+ args (argparse.Namespace or dict like): Parsed input arguments.
+ config (Coqpit): Model config. If none, it is generated from `args`. Defaults to None.
+ Returns:
+        c (Coqpit): Config parameters.
+ out_path (str): Path to save models and logging.
+ audio_path (str): Path to save generated test audios.
+ c_logger (TTS.utils.console_logger.ConsoleLogger): Class that does
+ logging to the console.
+        dashboard_logger (WandbLogger or TensorboardLogger): Class that does the dashboard logging.
+ TODO:
+ - Interactive config definition.
+ """
+ coqpit_overrides = None
+ if isinstance(args, tuple):
+ args, coqpit_overrides = args
+ if args.continue_path:
+ # continue a previous training from its output folder
+ experiment_path = args.continue_path
+ args.config_path = os.path.join(args.continue_path, "config.json")
+ args.restore_path, best_model = get_last_checkpoint(args.continue_path)
+ if not args.best_path:
+ args.best_path = best_model
+ # init config if not already defined
+ if config is None:
+ if args.config_path:
+ # init from a file
+ config = load_config(args.config_path)
+ else:
+ # init from console args
+ from TTS.config.shared_configs import BaseTrainingConfig # pylint: disable=import-outside-toplevel
+
+ config_base = BaseTrainingConfig()
+ config_base.parse_known_args(coqpit_overrides)
+ config = register_config(config_base.model)()
+ # override values from command-line args
+ config.parse_known_args(coqpit_overrides, relaxed_parser=True)
+ experiment_path = args.continue_path
+ if not experiment_path:
+ experiment_path = get_experiment_folder_path(config.output_path, config.run_name)
+ audio_path = os.path.join(experiment_path, "test_audios")
+ config.output_log_path = experiment_path
+ # setup rank 0 process in distributed training
+ dashboard_logger = None
+ if args.rank == 0:
+ new_fields = {}
+ if args.restore_path:
+ new_fields["restore_path"] = args.restore_path
+ new_fields["github_branch"] = get_git_branch()
+ # if model characters are not set in the config file
+ # save the default set to the config file for future
+ # compatibility.
+ if config.has("characters") and config.characters is None:
+ used_characters = parse_symbols()
+ new_fields["characters"] = used_characters
+ copy_model_files(config, experiment_path, new_fields)
+ dashboard_logger = logger_factory(config, experiment_path)
+ c_logger = ConsoleLogger()
+ return config, experiment_path, audio_path, c_logger, dashboard_logger
+
+
+def setup_loader(c: TrainerConfig, ap: AudioProcessor, is_val: bool = False):
num_utter_per_class = c.num_utter_per_class if not is_val else c.eval_num_utter_per_class
num_classes_in_batch = c.num_classes_in_batch if not is_val else c.eval_num_classes_in_batch
@@ -84,7 +159,7 @@ def setup_loader(ap: AudioProcessor, is_val: bool = False):
return loader, classes, dataset.get_map_classid_to_classname()
-def evaluation(model, criterion, data_loader, global_step):
+def evaluation(c: BaseEncoderConfig, model, criterion, data_loader, global_step, dashboard_logger: BaseDashboardLogger):
eval_loss = 0
for _, data in enumerate(data_loader):
with torch.inference_mode():
@@ -128,7 +203,17 @@ def evaluation(model, criterion, data_loader, global_step):
return eval_avg_loss
-def train(model, optimizer, scheduler, criterion, data_loader, eval_data_loader, global_step):
+def train(
+ c: BaseEncoderConfig,
+ model,
+ optimizer,
+ scheduler,
+ criterion,
+ data_loader,
+ eval_data_loader,
+ global_step,
+ dashboard_logger: BaseDashboardLogger,
+):
model.train()
best_loss = {"train_loss": None, "eval_loss": float("inf")}
avg_loader_time = 0
@@ -219,37 +304,33 @@ def train(model, optimizer, scheduler, criterion, data_loader, eval_data_loader,
if global_step % c.print_step == 0:
print(
- " | > Step:{} Loss:{:.5f} GradNorm:{:.5f} "
- "StepTime:{:.2f} LoaderTime:{:.2f} AvGLoaderTime:{:.2f} LR:{:.6f}".format(
- global_step, loss.item(), grad_norm, step_time, loader_time, avg_loader_time, current_lr
- ),
+ f" | > Step:{global_step} Loss:{loss.item():.5f} GradNorm:{grad_norm:.5f} "
+ f"StepTime:{step_time:.2f} LoaderTime:{loader_time:.2f} AvGLoaderTime:{avg_loader_time:.2f} LR:{current_lr:.6f}",
flush=True,
)
if global_step % c.save_step == 0:
# save model
save_checkpoint(
- c, model, optimizer, None, global_step, epoch, OUT_PATH, criterion=criterion.state_dict()
+ c, model, optimizer, None, global_step, epoch, c.output_log_path, criterion=criterion.state_dict()
)
end_time = time.time()
print("")
print(
- ">>> Epoch:{} AvgLoss: {:.5f} GradNorm:{:.5f} "
- "EpochTime:{:.2f} AvGLoaderTime:{:.2f} ".format(
- epoch, tot_loss / len(data_loader), grad_norm, epoch_time, avg_loader_time
- ),
+ f">>> Epoch:{epoch} AvgLoss: {tot_loss / len(data_loader):.5f} GradNorm:{grad_norm:.5f} "
+ f"EpochTime:{epoch_time:.2f} AvGLoaderTime:{avg_loader_time:.2f} ",
flush=True,
)
# evaluation
if c.run_eval:
model.eval()
- eval_loss = evaluation(model, criterion, eval_data_loader, global_step)
+ eval_loss = evaluation(c, model, criterion, eval_data_loader, global_step, dashboard_logger)
print("\n\n")
print("--> EVAL PERFORMANCE")
print(
- " | > Epoch:{} AvgLoss: {:.5f} ".format(epoch, eval_loss),
+ f" | > Epoch:{epoch} AvgLoss: {eval_loss:.5f} ",
flush=True,
)
# save the best checkpoint
@@ -262,7 +343,7 @@ def train(model, optimizer, scheduler, criterion, data_loader, eval_data_loader,
None,
global_step,
epoch,
- OUT_PATH,
+ c.output_log_path,
criterion=criterion.state_dict(),
)
model.train()
@@ -270,7 +351,13 @@ def train(model, optimizer, scheduler, criterion, data_loader, eval_data_loader,
return best_loss, global_step
-def main(args): # pylint: disable=redefined-outer-name
+def main(arg_list: list[str] | None = None):
+ setup_logger("TTS", level=logging.INFO, stream=sys.stdout, formatter=ConsoleFormatter())
+
+ train_config = TrainArgs()
+ parser = train_config.init_argparse(arg_prefix="")
+ args, overrides = parser.parse_known_args(arg_list)
+ c, OUT_PATH, AUDIO_PATH, c_logger, dashboard_logger = process_args((args, overrides))
# pylint: disable=global-variable-undefined
global meta_data_train
global meta_data_eval
@@ -284,9 +371,9 @@ def main(args): # pylint: disable=redefined-outer-name
# pylint: disable=redefined-outer-name
meta_data_train, meta_data_eval = load_tts_samples(c.datasets, eval_split=True)
- train_data_loader, train_classes, map_classid_to_classname = setup_loader(ap, is_val=False)
+ train_data_loader, train_classes, map_classid_to_classname = setup_loader(c, ap, is_val=False)
if c.run_eval:
- eval_data_loader, _, _ = setup_loader(ap, is_val=True)
+ eval_data_loader, _, _ = setup_loader(c, ap, is_val=True)
else:
eval_data_loader = None
@@ -301,7 +388,7 @@ def main(args): # pylint: disable=redefined-outer-name
criterion, args.restore_step = model.load_checkpoint(
c, args.restore_path, eval=False, use_cuda=use_cuda, criterion=criterion
)
- print(" > Model restored from step %d" % args.restore_step, flush=True)
+ print(f" > Model restored from step {args.restore_step}", flush=True)
else:
args.restore_step = 0
@@ -311,30 +398,30 @@ def main(args): # pylint: disable=redefined-outer-name
scheduler = None
num_params = count_parameters(model)
- print("\n > Model has {} parameters".format(num_params), flush=True)
+ print(f"\n > Model has {num_params} parameters", flush=True)
if use_cuda:
model = model.cuda()
criterion.cuda()
global_step = args.restore_step
- _, global_step = train(model, optimizer, scheduler, criterion, train_data_loader, eval_data_loader, global_step)
+ _, global_step = train(
+ c, model, optimizer, scheduler, criterion, train_data_loader, eval_data_loader, global_step, dashboard_logger
+ )
+ sys.exit(0)
if __name__ == "__main__":
- setup_logger("TTS", level=logging.INFO, stream=sys.stdout, formatter=ConsoleFormatter())
-
- args, c, OUT_PATH, AUDIO_PATH, c_logger, dashboard_logger = init_training()
-
- try:
- main(args)
- except KeyboardInterrupt:
- remove_experiment_folder(OUT_PATH)
- try:
- sys.exit(0)
- except SystemExit:
- os._exit(0) # pylint: disable=protected-access
- except Exception: # pylint: disable=broad-except
- remove_experiment_folder(OUT_PATH)
- traceback.print_exc()
- sys.exit(1)
+ main()
+ # try:
+ # main()
+ # except KeyboardInterrupt:
+ # remove_experiment_folder(OUT_PATH)
+ # try:
+ # sys.exit(0)
+ # except SystemExit:
+ # os._exit(0) # pylint: disable=protected-access
+ # except Exception: # pylint: disable=broad-except
+ # remove_experiment_folder(OUT_PATH)
+ # traceback.print_exc()
+ # sys.exit(1)
diff --git a/TTS/bin/train_tts.py b/TTS/bin/train_tts.py
index e93b1c9d24..deaa350878 100644
--- a/TTS/bin/train_tts.py
+++ b/TTS/bin/train_tts.py
@@ -16,7 +16,7 @@ class TrainTTSArgs(TrainerArgs):
config_path: str = field(default=None, metadata={"help": "Path to the config file."})
-def main():
+def main(arg_list: list[str] | None = None):
"""Run `tts` model training directly by a `config.json` file."""
setup_logger("TTS", level=logging.INFO, stream=sys.stdout, formatter=ConsoleFormatter())
@@ -24,8 +24,8 @@ def main():
train_args = TrainTTSArgs()
parser = train_args.init_argparse(arg_prefix="")
- # override trainer args from comman-line args
- args, config_overrides = parser.parse_known_args()
+ # override trainer args from command-line args
+ args, config_overrides = parser.parse_known_args(arg_list)
train_args.parse_args(args)
# load config.json and register
@@ -70,6 +70,7 @@ def main():
parse_command_line_args=False,
)
trainer.fit()
+ sys.exit(0)
if __name__ == "__main__":
diff --git a/TTS/bin/train_vocoder.py b/TTS/bin/train_vocoder.py
index 7cf5696237..58122b9005 100644
--- a/TTS/bin/train_vocoder.py
+++ b/TTS/bin/train_vocoder.py
@@ -2,7 +2,6 @@
import os
import sys
from dataclasses import dataclass, field
-from typing import Optional
from trainer import Trainer, TrainerArgs
@@ -18,7 +17,7 @@ class TrainVocoderArgs(TrainerArgs):
config_path: str = field(default=None, metadata={"help": "Path to the config file."})
-def main(arg_list: Optional[list[str]] = None):
+def main(arg_list: list[str] | None = None):
"""Run `tts` model training directly by a `config.json` file."""
setup_logger("TTS", level=logging.INFO, stream=sys.stdout, formatter=ConsoleFormatter())
diff --git a/TTS/config/__init__.py b/TTS/config/__init__.py
index e5f40c0296..401003504e 100644
--- a/TTS/config/__init__.py
+++ b/TTS/config/__init__.py
@@ -1,7 +1,7 @@
import json
import os
import re
-from typing import Any, Dict, Union
+from typing import Any, Union
import fsspec
import yaml
@@ -54,11 +54,11 @@ def register_config(model_name: str) -> Coqpit:
return config_class
-def _process_model_name(config_dict: Dict) -> str:
+def _process_model_name(config_dict: dict) -> str:
"""Format the model name as expected. It is a band-aid for the old `vocoder` model names.
Args:
- config_dict (Dict): A dictionary including the config fields.
+ config_dict (dict): A dictionary including the config fields.
Returns:
str: Formatted modelname.
@@ -68,7 +68,7 @@ def _process_model_name(config_dict: Dict) -> str:
return model_name
-def load_config(config_path: Union[str, os.PathLike[Any]]) -> Coqpit:
+def load_config(config_path: str | os.PathLike[Any]) -> Coqpit:
"""Import `json` or `yaml` files as TTS configs. First, load the input file as a `dict` and check the model name
to find the corresponding Config class. Then initialize the Config.
diff --git a/TTS/config/shared_configs.py b/TTS/config/shared_configs.py
index 7fae77d613..a0a013b0de 100644
--- a/TTS/config/shared_configs.py
+++ b/TTS/config/shared_configs.py
@@ -1,5 +1,4 @@
from dataclasses import asdict, dataclass
-from typing import List
from coqpit import Coqpit, check_argument
from trainer import TrainerConfig
@@ -227,7 +226,7 @@ class BaseDatasetConfig(Coqpit):
dataset_name: str = ""
path: str = ""
meta_file_train: str = ""
- ignored_speakers: List[str] = None
+ ignored_speakers: list[str] = None
language: str = ""
phonemizer: str = ""
meta_file_val: str = ""
diff --git a/TTS/demos/xtts_ft_demo/xtts_demo.py b/TTS/demos/xtts_ft_demo/xtts_demo.py
index 7ac38ed6ee..dac5f0870a 100644
--- a/TTS/demos/xtts_ft_demo/xtts_demo.py
+++ b/TTS/demos/xtts_ft_demo/xtts_demo.py
@@ -104,7 +104,7 @@ def isatty(self):
def read_logs():
sys.stdout.flush()
- with open(sys.stdout.log_file, "r") as f:
+ with open(sys.stdout.log_file) as f:
return f.read()
diff --git a/TTS/encoder/configs/base_encoder_config.py b/TTS/encoder/configs/base_encoder_config.py
index ebbaa0457b..d2d0ef580d 100644
--- a/TTS/encoder/configs/base_encoder_config.py
+++ b/TTS/encoder/configs/base_encoder_config.py
@@ -1,5 +1,4 @@
from dataclasses import asdict, dataclass, field
-from typing import Dict, List
from coqpit import MISSING
@@ -12,9 +11,9 @@ class BaseEncoderConfig(BaseTrainingConfig):
model: str = None
audio: BaseAudioConfig = field(default_factory=BaseAudioConfig)
- datasets: List[BaseDatasetConfig] = field(default_factory=lambda: [BaseDatasetConfig()])
+ datasets: list[BaseDatasetConfig] = field(default_factory=lambda: [BaseDatasetConfig()])
# model params
- model_params: Dict = field(
+ model_params: dict = field(
default_factory=lambda: {
"model_name": "lstm",
"input_dim": 80,
@@ -25,7 +24,7 @@ class BaseEncoderConfig(BaseTrainingConfig):
}
)
- audio_augmentation: Dict = field(default_factory=lambda: {})
+ audio_augmentation: dict = field(default_factory=lambda: {})
# training params
epochs: int = 10000
@@ -33,7 +32,7 @@ class BaseEncoderConfig(BaseTrainingConfig):
grad_clip: float = 3.0
lr: float = 0.0001
optimizer: str = "radam"
- optimizer_params: Dict = field(default_factory=lambda: {"betas": [0.9, 0.999], "weight_decay": 0})
+ optimizer_params: dict = field(default_factory=lambda: {"betas": [0.9, 0.999], "weight_decay": 0})
lr_decay: bool = False
warmup_steps: int = 4000
@@ -56,6 +55,6 @@ class BaseEncoderConfig(BaseTrainingConfig):
def check_values(self):
super().check_values()
c = asdict(self)
- assert (
- c["model_params"]["input_dim"] == self.audio.num_mels
- ), " [!] model input dimendion must be equal to melspectrogram dimension."
+ assert c["model_params"]["input_dim"] == self.audio.num_mels, (
+        " [!] model input dimension must be equal to melspectrogram dimension."
+ )
diff --git a/TTS/encoder/models/base_encoder.py b/TTS/encoder/models/base_encoder.py
index 603481cc56..c6680c3a25 100644
--- a/TTS/encoder/models/base_encoder.py
+++ b/TTS/encoder/models/base_encoder.py
@@ -34,7 +34,7 @@ class BaseEncoder(nn.Module):
# pylint: disable=W0102
def __init__(self):
- super(BaseEncoder, self).__init__()
+ super().__init__()
def get_torch_mel_spectrogram_class(self, audio_config):
return torch.nn.Sequential(
@@ -107,7 +107,7 @@ def get_criterion(self, c: Coqpit, num_classes=None):
elif c.loss == "softmaxproto":
criterion = SoftmaxAngleProtoLoss(c.model_params["proj_dim"], num_classes)
else:
- raise Exception("The %s not is a loss supported" % c.loss)
+            raise Exception(f"{c.loss} is not a supported loss")
return criterion
def load_checkpoint(
diff --git a/TTS/encoder/models/resnet.py b/TTS/encoder/models/resnet.py
index 5eafcd6005..d7f3a2f4bd 100644
--- a/TTS/encoder/models/resnet.py
+++ b/TTS/encoder/models/resnet.py
@@ -7,7 +7,7 @@
class SELayer(nn.Module):
def __init__(self, channel, reduction=8):
- super(SELayer, self).__init__()
+ super().__init__()
self.avg_pool = nn.AdaptiveAvgPool2d(1)
self.fc = nn.Sequential(
nn.Linear(channel, channel // reduction),
@@ -27,7 +27,7 @@ class SEBasicBlock(nn.Module):
expansion = 1
def __init__(self, inplanes, planes, stride=1, downsample=None, reduction=8):
- super(SEBasicBlock, self).__init__()
+ super().__init__()
self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
self.bn1 = nn.BatchNorm2d(planes)
self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1, bias=False)
@@ -73,7 +73,7 @@ def __init__(
use_torch_spec=False,
audio_config=None,
):
- super(ResNetSpeakerEncoder, self).__init__()
+ super().__init__()
self.encoder_type = encoder_type
self.input_dim = input_dim
diff --git a/TTS/encoder/utils/generic_utils.py b/TTS/encoder/utils/generic_utils.py
index 495b4def5a..54ab37a52f 100644
--- a/TTS/encoder/utils/generic_utils.py
+++ b/TTS/encoder/utils/generic_utils.py
@@ -6,13 +6,14 @@
import numpy as np
from scipy import signal
+from TTS.encoder.models.base_encoder import BaseEncoder
from TTS.encoder.models.lstm import LSTMSpeakerEncoder
from TTS.encoder.models.resnet import ResNetSpeakerEncoder
logger = logging.getLogger(__name__)
-class AugmentWAV(object):
+class AugmentWAV:
def __init__(self, ap, augmentation_config):
self.ap = ap
self.use_additive_noise = False
@@ -120,7 +121,7 @@ def apply_one(self, audio):
return self.additive_noise(noise_type, audio)
-def setup_encoder_model(config: "Coqpit"):
+def setup_encoder_model(config: "Coqpit") -> BaseEncoder:
if config.model_params["model_name"].lower() == "lstm":
model = LSTMSpeakerEncoder(
config.model_params["input_dim"],
@@ -138,4 +139,7 @@ def setup_encoder_model(config: "Coqpit"):
use_torch_spec=config.model_params.get("use_torch_spec", False),
audio_config=config.audio,
)
+ else:
+ msg = f"Model not supported: {config.model_params['model_name']}"
+ raise ValueError(msg)
return model
diff --git a/TTS/encoder/utils/prepare_voxceleb.py b/TTS/encoder/utils/prepare_voxceleb.py
index 37619ed0f8..8d50ffd5f5 100644
--- a/TTS/encoder/utils/prepare_voxceleb.py
+++ b/TTS/encoder/utils/prepare_voxceleb.py
@@ -1,4 +1,3 @@
-# coding=utf-8
# Copyright (C) 2020 ATHENA AUTHORS; Yiping Peng; Ne Luo
# All rights reserved.
#
@@ -17,7 +16,7 @@
# Only support eager mode and TF>=2.0.0
# pylint: disable=no-member, invalid-name, relative-beyond-top-level
# pylint: disable=too-many-locals, too-many-statements, too-many-arguments, too-many-instance-attributes
-""" voxceleb 1 & 2 """
+"""voxceleb 1 & 2"""
import csv
import hashlib
@@ -81,19 +80,19 @@ def download_and_extract(directory, subset, urls):
zip_filepath = os.path.join(directory, url.split("/")[-1])
if os.path.exists(zip_filepath):
continue
- logger.info("Downloading %s to %s" % (url, zip_filepath))
+ logger.info("Downloading %s to %s", url, zip_filepath)
subprocess.call(
- "wget %s --user %s --password %s -O %s" % (url, USER["user"], USER["password"], zip_filepath),
+ "wget {} --user {} --password {} -O {}".format(url, USER["user"], USER["password"], zip_filepath),
shell=True,
)
statinfo = os.stat(zip_filepath)
- logger.info("Successfully downloaded %s, size(bytes): %d" % (url, statinfo.st_size))
+ logger.info("Successfully downloaded %s, size(bytes): %d", url, statinfo.st_size)
# concatenate all parts into zip files
if ".zip" not in zip_filepath:
zip_filepath = "_".join(zip_filepath.split("_")[:-1])
- subprocess.call("cat %s* > %s.zip" % (zip_filepath, zip_filepath), shell=True)
+ subprocess.call(f"cat {zip_filepath}* > {zip_filepath}.zip", shell=True)
zip_filepath += ".zip"
extract_path = zip_filepath.strip(".zip")
@@ -101,12 +100,12 @@ def download_and_extract(directory, subset, urls):
with open(zip_filepath, "rb") as f_zip:
md5 = hashlib.md5(f_zip.read()).hexdigest()
if md5 != MD5SUM[subset]:
- raise ValueError("md5sum of %s mismatch" % zip_filepath)
+ raise ValueError(f"md5sum of {zip_filepath} mismatch")
with zipfile.ZipFile(zip_filepath, "r") as zfile:
zfile.extractall(directory)
extract_path_ori = os.path.join(directory, zfile.infolist()[0].filename)
- subprocess.call("mv %s %s" % (extract_path_ori, extract_path), shell=True)
+ subprocess.call(f"mv {extract_path_ori} {extract_path}", shell=True)
finally:
# os.remove(zip_filepath)
pass
@@ -122,9 +121,9 @@ def exec_cmd(cmd):
try:
retcode = subprocess.call(cmd, shell=True)
if retcode < 0:
- logger.info(f"Child was terminated by signal {retcode}")
+ logger.info("Child was terminated by signal %d", retcode)
except OSError as e:
- logger.info(f"Execution failed: {e}")
+ logger.info("Execution failed: %s", e)
retcode = -999
return retcode
@@ -138,10 +137,10 @@ def decode_aac_with_ffmpeg(aac_file, wav_file):
bool, True if success.
"""
cmd = f"ffmpeg -i {aac_file} {wav_file}"
- logger.info(f"Decoding aac file using command line: {cmd}")
+ logger.info("Decoding aac file using command line: %s", cmd)
ret = exec_cmd(cmd)
if ret != 0:
- logger.error(f"Failed to decode aac file with retcode {ret}")
+ logger.error("Failed to decode aac file with retcode %s", ret)
logger.error("Please check your ffmpeg installation.")
return False
return True
@@ -156,7 +155,7 @@ def convert_audio_and_make_label(input_dir, subset, output_dir, output_file):
output_file: the name of the newly generated csv file. e.g. vox1_dev_wav.csv
"""
- logger.info("Preprocessing audio and label for subset %s" % subset)
+ logger.info("Preprocessing audio and label for subset %s", subset)
source_dir = os.path.join(input_dir, subset)
files = []
@@ -194,7 +193,7 @@ def convert_audio_and_make_label(input_dir, subset, output_dir, output_file):
writer.writerow(["wav_filename", "wav_length_ms", "speaker_id", "speaker_name"])
for wav_file in files:
writer.writerow(wav_file)
- logger.info("Successfully generated csv file {}".format(csv_file_path))
+ logger.info("Successfully generated csv file %s", csv_file_path)
def processor(directory, subset, force_process):
diff --git a/TTS/encoder/utils/training.py b/TTS/encoder/utils/training.py
deleted file mode 100644
index 48629c7a57..0000000000
--- a/TTS/encoder/utils/training.py
+++ /dev/null
@@ -1,99 +0,0 @@
-import os
-from dataclasses import dataclass, field
-
-from coqpit import Coqpit
-from trainer import TrainerArgs
-from trainer.generic_utils import get_experiment_folder_path, get_git_branch
-from trainer.io import copy_model_files, get_last_checkpoint
-from trainer.logging import logger_factory
-from trainer.logging.console_logger import ConsoleLogger
-
-from TTS.config import load_config, register_config
-from TTS.tts.utils.text.characters import parse_symbols
-
-
-@dataclass
-class TrainArgs(TrainerArgs):
- config_path: str = field(default=None, metadata={"help": "Path to the config file."})
-
-
-def getarguments():
- train_config = TrainArgs()
- parser = train_config.init_argparse(arg_prefix="")
- return parser
-
-
-def process_args(args, config=None):
- """Process parsed comand line arguments and initialize the config if not provided.
- Args:
- args (argparse.Namespace or dict like): Parsed input arguments.
- config (Coqpit): Model config. If none, it is generated from `args`. Defaults to None.
- Returns:
- c (Coqpit): Config paramaters.
- out_path (str): Path to save models and logging.
- audio_path (str): Path to save generated test audios.
- c_logger (TTS.utils.console_logger.ConsoleLogger): Class that does
- logging to the console.
- dashboard_logger (WandbLogger or TensorboardLogger): Class that does the dashboard Logging
- TODO:
- - Interactive config definition.
- """
- if isinstance(args, tuple):
- args, coqpit_overrides = args
- if args.continue_path:
- # continue a previous training from its output folder
- experiment_path = args.continue_path
- args.config_path = os.path.join(args.continue_path, "config.json")
- args.restore_path, best_model = get_last_checkpoint(args.continue_path)
- if not args.best_path:
- args.best_path = best_model
- # init config if not already defined
- if config is None:
- if args.config_path:
- # init from a file
- config = load_config(args.config_path)
- else:
- # init from console args
- from TTS.config.shared_configs import BaseTrainingConfig # pylint: disable=import-outside-toplevel
-
- config_base = BaseTrainingConfig()
- config_base.parse_known_args(coqpit_overrides)
- config = register_config(config_base.model)()
- # override values from command-line args
- config.parse_known_args(coqpit_overrides, relaxed_parser=True)
- experiment_path = args.continue_path
- if not experiment_path:
- experiment_path = get_experiment_folder_path(config.output_path, config.run_name)
- audio_path = os.path.join(experiment_path, "test_audios")
- config.output_log_path = experiment_path
- # setup rank 0 process in distributed training
- dashboard_logger = None
- if args.rank == 0:
- new_fields = {}
- if args.restore_path:
- new_fields["restore_path"] = args.restore_path
- new_fields["github_branch"] = get_git_branch()
- # if model characters are not set in the config file
- # save the default set to the config file for future
- # compatibility.
- if config.has("characters") and config.characters is None:
- used_characters = parse_symbols()
- new_fields["characters"] = used_characters
- copy_model_files(config, experiment_path, new_fields)
- dashboard_logger = logger_factory(config, experiment_path)
- c_logger = ConsoleLogger()
- return config, experiment_path, audio_path, c_logger, dashboard_logger
-
-
-def init_arguments():
- train_config = TrainArgs()
- parser = train_config.init_argparse(arg_prefix="")
- return parser
-
-
-def init_training(config: Coqpit = None):
- """Initialization of a training run."""
- parser = init_arguments()
- args = parser.parse_known_args()
- config, OUT_PATH, AUDIO_PATH, c_logger, dashboard_logger = process_args(args, config)
- return args[0], config, OUT_PATH, AUDIO_PATH, c_logger, dashboard_logger
diff --git a/TTS/model.py b/TTS/model.py
index e024ad1a44..39faa7f690 100644
--- a/TTS/model.py
+++ b/TTS/model.py
@@ -1,6 +1,6 @@
import os
from abc import abstractmethod
-from typing import Any, Union
+from typing import Any
import torch
from coqpit import Coqpit
@@ -48,7 +48,7 @@ def inference(self, input: torch.Tensor, aux_input: dict[str, Any] = {}) -> dict
def load_checkpoint(
self,
config: Coqpit,
- checkpoint_path: Union[str, os.PathLike[Any]],
+ checkpoint_path: str | os.PathLike[Any],
eval: bool = False,
strict: bool = True,
cache: bool = False,
diff --git a/TTS/server/README.md b/TTS/server/README.md
index ae8e38a4e3..232b8618d8 100644
--- a/TTS/server/README.md
+++ b/TTS/server/README.md
@@ -1,21 +1,36 @@
-# :frog: TTS demo server
+# :frog: TTS Demo Server
Before you use the server, make sure you
-[install](https://github.com/idiap/coqui-ai-TTS/tree/dev#install-tts)) :frog: TTS
+[install](https://github.com/idiap/coqui-ai-TTS/tree/dev#install-tts) :frog: TTS
properly and install the additional dependencies with `pip install
coqui-tts[server]`. Then, you can follow the steps below.
-**Note:** If you install :frog:TTS using ```pip```, you can also use the ```tts-server``` end point on the terminal.
+**Note:** If you install :frog:TTS using ```pip```, you can also use the ```tts-server``` entry point on the terminal instead of the `python TTS/server/server.py` commands below.
-Examples runs:
+## Example commands
-List officially released models.
-```python TTS/server/server.py --list_models ```
+List officially released models:
+```bash
+python TTS/server/server.py --list_models # or
+tts-server --list_models
+```
-Run the server with the official models.
-```python TTS/server/server.py --model_name tts_models/en/ljspeech/tacotron2-DCA --vocoder_name vocoder_models/en/ljspeech/multiband-melgan```
+Run the server with the official models:
+```bash
+python TTS/server/server.py --model_name tts_models/en/ljspeech/tacotron2-DCA \
+ --vocoder_name vocoder_models/en/ljspeech/multiband-melgan
+```
-Run the server with the official models on a GPU.
-```CUDA_VISIBLE_DEVICES="0" python TTS/server/server.py --model_name tts_models/en/ljspeech/tacotron2-DCA --vocoder_name vocoder_models/en/ljspeech/multiband-melgan --use_cuda```
+Run the server with the official models on a GPU:
+```bash
+CUDA_VISIBLE_DEVICES="0" python TTS/server/server.py \
+    --model_name tts_models/en/ljspeech/tacotron2-DCA \
+ --vocoder_name vocoder_models/en/ljspeech/multiband-melgan --use_cuda
+```
-Run the server with a custom models.
-```python TTS/server/server.py --tts_checkpoint /path/to/tts/model.pth --tts_config /path/to/tts/config.json --vocoder_checkpoint /path/to/vocoder/model.pth --vocoder_config /path/to/vocoder/config.json```
+Run the server with custom models:
+```bash
+python TTS/server/server.py --tts_checkpoint /path/to/tts/model.pth \
+ --tts_config /path/to/tts/config.json \
+ --vocoder_checkpoint /path/to/vocoder/model.pth \
+ --vocoder_config /path/to/vocoder/config.json
+```
diff --git a/TTS/server/server.py b/TTS/server/server.py
index cb4ed4d9b2..500c706c4e 100644
--- a/TTS/server/server.py
+++ b/TTS/server/server.py
@@ -9,7 +9,6 @@
import os
import sys
from threading import Lock
-from typing import Union
from urllib.parse import parse_qs
try:
@@ -105,10 +104,11 @@ def create_argparser() -> argparse.ArgumentParser:
# TODO: set this from SpeakerManager
use_gst = api.synthesizer.tts_config.get("use_gst", False)
+supports_cloning = api.synthesizer.tts_config.get("model", "") in ["xtts", "bark"]
app = Flask(__name__)
-def style_wav_uri_to_dict(style_wav: str) -> Union[str, dict]:
+def style_wav_uri_to_dict(style_wav: str) -> str | dict:
"""Transform an uri style_wav, in either a string (path to wav file to be use for style transfer)
or a dict (gst tokens/values to be use for styling)
@@ -137,6 +137,7 @@ def index():
speaker_ids=api.speakers,
language_ids=api.languages,
use_gst=use_gst,
+ supports_cloning=supports_cloning,
)
@@ -164,6 +165,8 @@ def tts():
speaker_idx = (
request.headers.get("speaker-id") or request.values.get("speaker_id", "") if api.is_multi_speaker else None
)
+ if speaker_idx == "":
+ speaker_idx = None
language_idx = (
request.headers.get("language-id") or request.values.get("language_id", "")
if api.is_multi_lingual
@@ -171,11 +174,12 @@ def tts():
)
style_wav = request.headers.get("style-wav") or request.values.get("style_wav", "")
style_wav = style_wav_uri_to_dict(style_wav)
+ speaker_wav = request.headers.get("speaker-wav") or request.values.get("speaker_wav", "")
logger.info("Model input: %s", text)
logger.info("Speaker idx: %s", speaker_idx)
logger.info("Language idx: %s", language_idx)
- wavs = api.tts(text, speaker=speaker_idx, language=language_idx, style_wav=style_wav)
+ wavs = api.tts(text, speaker=speaker_idx, language=language_idx, style_wav=style_wav, speaker_wav=speaker_wav)
out = io.BytesIO()
api.synthesizer.save_wav(wavs, out)
return send_file(out, mimetype="audio/wav")
diff --git a/TTS/server/templates/index.html b/TTS/server/templates/index.html
index 6bfd5ae2cb..95d7076394 100644
--- a/TTS/server/templates/index.html
+++ b/TTS/server/templates/index.html
@@ -66,7 +66,12 @@
{%if use_gst%}
+ type="text" name="style_wav">
+ {%endif%}
+
+ {%if supports_cloning%}
+ Reference audio:
+
{%endif%}
@@ -114,14 +119,18 @@
q('#text').focus()
function do_tts(e) {
const text = q('#text').value
- const speaker_id = getTextValue('#speaker_id')
const style_wav = getTextValue('#style_wav')
+ const speaker_wav = getTextValue('#speaker_wav')
+ let speaker_id = getTextValue('#speaker_id')
+ if (speaker_wav !== '') {
+ speaker_id = ''
+ }
const language_id = getTextValue('#language_id')
if (text) {
q('#message').textContent = 'Synthesizing...'
q('#speak-button').disabled = true
q('#audio').hidden = true
- synthesize(text, speaker_id, style_wav, language_id)
+ synthesize(text, speaker_id, style_wav, speaker_wav, language_id)
}
e.preventDefault()
return false
@@ -132,8 +141,8 @@
do_tts(e)
}
})
- function synthesize(text, speaker_id = "", style_wav = "", language_id = "") {
- fetch(`/api/tts?text=${encodeURIComponent(text)}&speaker_id=${encodeURIComponent(speaker_id)}&style_wav=${encodeURIComponent(style_wav)}&language_id=${encodeURIComponent(language_id)}`, { cache: 'no-cache' })
+ function synthesize(text, speaker_id = "", style_wav = "", speaker_wav = "", language_id = "") {
+ fetch(`/api/tts?text=${encodeURIComponent(text)}&speaker_id=${encodeURIComponent(speaker_id)}&style_wav=${encodeURIComponent(style_wav)}&speaker_wav=${encodeURIComponent(speaker_wav)}&language_id=${encodeURIComponent(language_id)}`, { cache: 'no-cache' })
.then(function (res) {
if (!res.ok) throw Error(res.statusText)
return res.blob()
diff --git a/TTS/tts/configs/align_tts_config.py b/TTS/tts/configs/align_tts_config.py
index 317a01af53..784819eee3 100644
--- a/TTS/tts/configs/align_tts_config.py
+++ b/TTS/tts/configs/align_tts_config.py
@@ -1,5 +1,4 @@
from dataclasses import dataclass, field
-from typing import List
from TTS.tts.configs.shared_configs import BaseTTSConfig
from TTS.tts.models.align_tts import AlignTTSArgs
@@ -70,7 +69,7 @@ class AlignTTSConfig(BaseTTSConfig):
model: str = "align_tts"
# model specific params
model_args: AlignTTSArgs = field(default_factory=AlignTTSArgs)
- phase_start_steps: List[int] = None
+ phase_start_steps: list[int] | None = None
ssim_alpha: float = 1.0
spec_loss_alpha: float = 1.0
@@ -80,13 +79,13 @@ class AlignTTSConfig(BaseTTSConfig):
# multi-speaker settings
use_speaker_embedding: bool = False
use_d_vector_file: bool = False
- d_vector_file: str = False
+ d_vector_file: str | None = None
# optimizer parameters
optimizer: str = "Adam"
optimizer_params: dict = field(default_factory=lambda: {"betas": [0.9, 0.998], "weight_decay": 1e-6})
- lr_scheduler: str = None
- lr_scheduler_params: dict = None
+ lr_scheduler: str | None = None
+ lr_scheduler_params: dict | None = None
lr: float = 1e-4
grad_clip: float = 5.0
@@ -96,7 +95,7 @@ class AlignTTSConfig(BaseTTSConfig):
r: int = 1
# testing
- test_sentences: List[str] = field(
+ test_sentences: list[str] | list[list[str]] = field(
default_factory=lambda: [
"It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.",
"Be a voice, not an echo.",
diff --git a/TTS/tts/configs/bark_config.py b/TTS/tts/configs/bark_config.py
index b846febe85..61d67b987a 100644
--- a/TTS/tts/configs/bark_config.py
+++ b/TTS/tts/configs/bark_config.py
@@ -1,6 +1,5 @@
import os
from dataclasses import dataclass, field
-from typing import Dict
from trainer.io import get_user_data_dir
@@ -70,9 +69,9 @@ class BarkConfig(BaseTTSConfig):
COARSE_INFER_TOKEN: int = 12_050
REMOTE_BASE_URL = "https://huggingface.co/erogol/bark/tree/main/"
- REMOTE_MODEL_PATHS: Dict = None
- LOCAL_MODEL_PATHS: Dict = None
- SMALL_REMOTE_MODEL_PATHS: Dict = None
+ REMOTE_MODEL_PATHS: dict = None
+ LOCAL_MODEL_PATHS: dict = None
+ SMALL_REMOTE_MODEL_PATHS: dict = None
CACHE_DIR: str = str(get_user_data_dir("tts/suno/bark_v0"))
DEF_SPEAKER_DIR: str = str(get_user_data_dir("tts/bark_v0/speakers"))
diff --git a/TTS/tts/configs/delightful_tts_config.py b/TTS/tts/configs/delightful_tts_config.py
index 805d995369..fc9a76f613 100644
--- a/TTS/tts/configs/delightful_tts_config.py
+++ b/TTS/tts/configs/delightful_tts_config.py
@@ -1,5 +1,4 @@
from dataclasses import dataclass, field
-from typing import List
from TTS.tts.configs.shared_configs import BaseTTSConfig
from TTS.tts.models.delightful_tts import DelightfulTtsArgs, DelightfulTtsAudioConfig, VocoderConfig
@@ -73,7 +72,7 @@ class DelightfulTTSConfig(BaseTTSConfig):
# optimizer
steps_to_start_discriminator: int = 200000
- grad_clip: List[float] = field(default_factory=lambda: [1000, 1000])
+ grad_clip: list[float] = field(default_factory=lambda: [1000, 1000])
lr_gen: float = 0.0002
lr_disc: float = 0.0002
lr_scheduler_gen: str = "ExponentialLR"
@@ -140,7 +139,7 @@ class DelightfulTTSConfig(BaseTTSConfig):
d_vector_dim: int = None
# testing
- test_sentences: List[List[str]] = field(
+ test_sentences: list[str] | list[list[str]] = field(
default_factory=lambda: [
["It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent."],
["Be a voice, not an echo."],
diff --git a/TTS/tts/configs/fast_pitch_config.py b/TTS/tts/configs/fast_pitch_config.py
index d086d26564..1342856668 100644
--- a/TTS/tts/configs/fast_pitch_config.py
+++ b/TTS/tts/configs/fast_pitch_config.py
@@ -1,5 +1,4 @@
from dataclasses import dataclass, field
-from typing import List
from TTS.tts.configs.shared_configs import BaseTTSConfig
from TTS.tts.models.forward_tts import ForwardTTSArgs
@@ -117,10 +116,10 @@ class FastPitchConfig(BaseTTSConfig):
# multi-speaker settings
num_speakers: int = 0
- speakers_file: str = None
+ speakers_file: str | None = None
use_speaker_embedding: bool = False
use_d_vector_file: bool = False
- d_vector_file: str = False
+ d_vector_file: str | None = None
d_vector_dim: int = 0
# optimizer parameters
@@ -150,10 +149,10 @@ class FastPitchConfig(BaseTTSConfig):
# dataset configs
compute_f0: bool = True
- f0_cache_path: str = None
+ f0_cache_path: str | None = None
# testing
- test_sentences: List[str] = field(
+ test_sentences: list[str] | list[list[str]] = field(
default_factory=lambda: [
"It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.",
"Be a voice, not an echo.",
diff --git a/TTS/tts/configs/fast_speech_config.py b/TTS/tts/configs/fast_speech_config.py
index af6c2db6fa..408dbab196 100644
--- a/TTS/tts/configs/fast_speech_config.py
+++ b/TTS/tts/configs/fast_speech_config.py
@@ -1,5 +1,4 @@
from dataclasses import dataclass, field
-from typing import List
from TTS.tts.configs.shared_configs import BaseTTSConfig
from TTS.tts.models.forward_tts import ForwardTTSArgs
@@ -111,10 +110,10 @@ class FastSpeechConfig(BaseTTSConfig):
# multi-speaker settings
num_speakers: int = 0
- speakers_file: str = None
+ speakers_file: str | None = None
use_speaker_embedding: bool = False
use_d_vector_file: bool = False
- d_vector_file: str = False
+ d_vector_file: str | None = None
d_vector_dim: int = 0
# optimizer parameters
@@ -144,10 +143,10 @@ class FastSpeechConfig(BaseTTSConfig):
# dataset configs
compute_f0: bool = False
- f0_cache_path: str = None
+ f0_cache_path: str | None = None
# testing
- test_sentences: List[str] = field(
+ test_sentences: list[str] | list[list[str]] = field(
default_factory=lambda: [
"It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.",
"Be a voice, not an echo.",
diff --git a/TTS/tts/configs/fastspeech2_config.py b/TTS/tts/configs/fastspeech2_config.py
index d179617fb0..44bdefad0d 100644
--- a/TTS/tts/configs/fastspeech2_config.py
+++ b/TTS/tts/configs/fastspeech2_config.py
@@ -1,5 +1,4 @@
from dataclasses import dataclass, field
-from typing import List
from TTS.tts.configs.shared_configs import BaseTTSConfig
from TTS.tts.models.forward_tts import ForwardTTSArgs
@@ -127,10 +126,10 @@ class Fastspeech2Config(BaseTTSConfig):
# multi-speaker settings
num_speakers: int = 0
- speakers_file: str = None
+ speakers_file: str | None = None
use_speaker_embedding: bool = False
use_d_vector_file: bool = False
- d_vector_file: str = False
+ d_vector_file: str | None = None
d_vector_dim: int = 0
# optimizer parameters
@@ -161,14 +160,14 @@ class Fastspeech2Config(BaseTTSConfig):
# dataset configs
compute_f0: bool = True
- f0_cache_path: str = None
+ f0_cache_path: str | None = None
# dataset configs
compute_energy: bool = True
- energy_cache_path: str = None
+ energy_cache_path: str | None = None
# testing
- test_sentences: List[str] = field(
+ test_sentences: list[str] | list[list[str]] = field(
default_factory=lambda: [
"It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.",
"Be a voice, not an echo.",
diff --git a/TTS/tts/configs/glow_tts_config.py b/TTS/tts/configs/glow_tts_config.py
index f42f3e5a51..c99e920b9d 100644
--- a/TTS/tts/configs/glow_tts_config.py
+++ b/TTS/tts/configs/glow_tts_config.py
@@ -1,5 +1,4 @@
from dataclasses import dataclass, field
-from typing import List
from TTS.tts.configs.shared_configs import BaseTTSConfig
@@ -101,7 +100,7 @@ class GlowTTSConfig(BaseTTSConfig):
model: str = "glow_tts"
# model params
- num_chars: int = None
+ num_chars: int | None = None
encoder_type: str = "rel_pos_transformer"
encoder_params: dict = field(
default_factory=lambda: {
@@ -147,15 +146,15 @@ class GlowTTSConfig(BaseTTSConfig):
data_dep_init_steps: int = 10
# inference params
- style_wav_for_test: str = None
+ style_wav_for_test: str | None = None
inference_noise_scale: float = 0.0
length_scale: float = 1.0
# multi-speaker settings
use_speaker_embedding: bool = False
- speakers_file: str = None
+ speakers_file: str | None = None
use_d_vector_file: bool = False
- d_vector_file: str = False
+ d_vector_file: str | None = None
# optimizer parameters
optimizer: str = "RAdam"
@@ -171,7 +170,7 @@ class GlowTTSConfig(BaseTTSConfig):
r: int = 1 # DO NOT CHANGE - TODO: make this immutable once coqpit implements it.
# testing
- test_sentences: List[str] = field(
+ test_sentences: list[str] | list[list[str]] = field(
default_factory=lambda: [
"It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.",
"Be a voice, not an echo.",
diff --git a/TTS/tts/configs/neuralhmm_tts_config.py b/TTS/tts/configs/neuralhmm_tts_config.py
index 50f72847ed..108f2022d4 100644
--- a/TTS/tts/configs/neuralhmm_tts_config.py
+++ b/TTS/tts/configs/neuralhmm_tts_config.py
@@ -1,5 +1,4 @@
from dataclasses import dataclass, field
-from typing import List
from TTS.tts.configs.shared_configs import BaseTTSConfig
@@ -126,7 +125,7 @@ class NeuralhmmTTSConfig(BaseTTSConfig):
memory_rnn_dim: int = 1024
## Outputnet parameters
- outputnet_size: List[int] = field(default_factory=lambda: [1024])
+ outputnet_size: list[int] = field(default_factory=lambda: [1024])
flat_start_params: dict = field(default_factory=lambda: {"mean": 0.0, "std": 1.0, "transition_p": 0.14})
std_floor: float = 0.001
@@ -143,7 +142,7 @@ class NeuralhmmTTSConfig(BaseTTSConfig):
min_audio_len: int = 512
# testing
- test_sentences: List[str] = field(
+ test_sentences: list[str] | list[list[str]] = field(
default_factory=lambda: [
"Be a voice, not an echo.",
]
@@ -162,9 +161,9 @@ def check_values(self):
AssertionError: transition probability is not between 0 and 1
"""
assert self.ar_order > 0, "AR order must be greater than 0 it is an autoregressive model."
- assert (
- len(self.outputnet_size) >= 1
- ), f"Parameter Network must have atleast one layer check the config file for parameter network. Provided: {self.parameternetwork}"
- assert (
- 0 < self.flat_start_params["transition_p"] < 1
- ), f"Transition probability must be between 0 and 1. Provided: {self.flat_start_params['transition_p']}"
+ assert len(self.outputnet_size) >= 1, (
+ f"Parameter Network must have atleast one layer check the config file for parameter network. Provided: {self.parameternetwork}"
+ )
+ assert 0 < self.flat_start_params["transition_p"] < 1, (
+ f"Transition probability must be between 0 and 1. Provided: {self.flat_start_params['transition_p']}"
+ )
diff --git a/TTS/tts/configs/overflow_config.py b/TTS/tts/configs/overflow_config.py
index dc3e5548b8..9e96aaa441 100644
--- a/TTS/tts/configs/overflow_config.py
+++ b/TTS/tts/configs/overflow_config.py
@@ -1,5 +1,4 @@
from dataclasses import dataclass, field
-from typing import List
from TTS.tts.configs.shared_configs import BaseTTSConfig
@@ -145,7 +144,7 @@ class OverflowConfig(BaseTTSConfig): # The classname has to be camel case
memory_rnn_dim: int = 1024
## Outputnet parameters
- outputnet_size: List[int] = field(default_factory=lambda: [1024])
+ outputnet_size: list[int] = field(default_factory=lambda: [1024])
flat_start_params: dict = field(default_factory=lambda: {"mean": 0.0, "std": 1.0, "transition_p": 0.14})
std_floor: float = 0.01
@@ -174,7 +173,7 @@ class OverflowConfig(BaseTTSConfig): # The classname has to be camel case
min_audio_len: int = 512
# testing
- test_sentences: List[str] = field(
+ test_sentences: list[str] | list[list[str]] = field(
default_factory=lambda: [
"Be a voice, not an echo.",
]
@@ -193,9 +192,9 @@ def check_values(self):
AssertionError: transition probability is not between 0 and 1
"""
assert self.ar_order > 0, "AR order must be greater than 0 it is an autoregressive model."
- assert (
- len(self.outputnet_size) >= 1
- ), f"Parameter Network must have atleast one layer check the config file for parameter network. Provided: {self.parameternetwork}"
- assert (
- 0 < self.flat_start_params["transition_p"] < 1
- ), f"Transition probability must be between 0 and 1. Provided: {self.flat_start_params['transition_p']}"
+ assert len(self.outputnet_size) >= 1, (
+ f"Parameter Network must have atleast one layer check the config file for parameter network. Provided: {self.parameternetwork}"
+ )
+ assert 0 < self.flat_start_params["transition_p"] < 1, (
+ f"Transition probability must be between 0 and 1. Provided: {self.flat_start_params['transition_p']}"
+ )
diff --git a/TTS/tts/configs/shared_configs.py b/TTS/tts/configs/shared_configs.py
index bf17322c19..c62f68306d 100644
--- a/TTS/tts/configs/shared_configs.py
+++ b/TTS/tts/configs/shared_configs.py
@@ -1,5 +1,4 @@
from dataclasses import asdict, dataclass, field
-from typing import Dict, List
from coqpit import Coqpit, check_argument
@@ -138,7 +137,7 @@ class CharactersConfig(Coqpit):
characters_class: str = None
# using BaseVocabulary
- vocab_dict: Dict = None
+ vocab_dict: dict = None
# using on BaseCharacters
pad: str = None
@@ -323,7 +322,7 @@ class BaseTTSConfig(BaseTrainingConfig):
shuffle: bool = False
drop_last: bool = False
# dataset
- datasets: List[BaseDatasetConfig] = field(default_factory=lambda: [BaseDatasetConfig()])
+ datasets: list[BaseDatasetConfig] = field(default_factory=lambda: [BaseDatasetConfig()])
# optimizer
optimizer: str = "radam"
optimizer_params: dict = None
@@ -331,7 +330,7 @@ class BaseTTSConfig(BaseTrainingConfig):
lr_scheduler: str = None
lr_scheduler_params: dict = field(default_factory=lambda: {})
# testing
- test_sentences: List[str] = field(default_factory=lambda: [])
+ test_sentences: list[str] | list[list[str]] = field(default_factory=lambda: [])
# evaluation
eval_split_max_size: int = None
eval_split_size: float = 0.01
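`BaseTTSConfig` above widens `test_sentences` to `list[str] | list[list[str]]`, matching the two shapes already used as defaults in the individual configs (flat strings for most models, single-element lists for VITS and DelightfulTTS). A small sketch of both accepted forms, with illustrative values:

```python
# Both shapes now satisfy the same annotation:
#   test_sentences: list[str] | list[list[str]]
flat_form: list[str] = [
    "Be a voice, not an echo.",
]
nested_form: list[list[str]] = [
    ["Be a voice, not an echo."],  # per-sentence list form used by the VITS/DelightfulTTS defaults
]
```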
diff --git a/TTS/tts/configs/speedy_speech_config.py b/TTS/tts/configs/speedy_speech_config.py
index bf8517dfc4..b37ba174bf 100644
--- a/TTS/tts/configs/speedy_speech_config.py
+++ b/TTS/tts/configs/speedy_speech_config.py
@@ -1,5 +1,4 @@
from dataclasses import dataclass, field
-from typing import List
from TTS.tts.configs.shared_configs import BaseTTSConfig
from TTS.tts.models.forward_tts import ForwardTTSArgs
@@ -129,10 +128,10 @@ class SpeedySpeechConfig(BaseTTSConfig):
# multi-speaker settings
num_speakers: int = 0
- speakers_file: str = None
+ speakers_file: str | None = None
use_speaker_embedding: bool = False
use_d_vector_file: bool = False
- d_vector_file: str = False
+ d_vector_file: str | None = None
d_vector_dim: int = 0
# optimizer parameters
@@ -161,10 +160,10 @@ class SpeedySpeechConfig(BaseTTSConfig):
# dataset configs
compute_f0: bool = False
- f0_cache_path: str = None
+ f0_cache_path: str | None = None
# testing
- test_sentences: List[str] = field(
+ test_sentences: list[str] | list[list[str]] = field(
default_factory=lambda: [
"It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.",
"Be a voice, not an echo.",
diff --git a/TTS/tts/configs/tacotron_config.py b/TTS/tts/configs/tacotron_config.py
index 350b5ea996..caa118815a 100644
--- a/TTS/tts/configs/tacotron_config.py
+++ b/TTS/tts/configs/tacotron_config.py
@@ -1,5 +1,4 @@
from dataclasses import dataclass, field
-from typing import List
from TTS.tts.configs.shared_configs import BaseTTSConfig, CapacitronVAEConfig, GSTConfig
@@ -154,7 +153,7 @@ class TacotronConfig(BaseTTSConfig):
num_speakers: int = 1
num_chars: int = 0
r: int = 2
- gradual_training: List[List[int]] = None
+ gradual_training: list[list[int]] = None
memory_size: int = -1
prenet_type: str = "original"
prenet_dropout: bool = True
@@ -170,7 +169,7 @@ class TacotronConfig(BaseTTSConfig):
# attention layers
attention_type: str = "original"
- attention_heads: int = None
+ attention_heads: int | None = None
attention_norm: str = "sigmoid"
attention_win: bool = False
windowing: bool = False
@@ -189,8 +188,8 @@ class TacotronConfig(BaseTTSConfig):
use_speaker_embedding: bool = False
speaker_embedding_dim: int = 512
use_d_vector_file: bool = False
- d_vector_file: str = False
- d_vector_dim: int = None
+ d_vector_file: str | None = None
+ d_vector_dim: int | None = None
# optimizer parameters
optimizer: str = "RAdam"
@@ -212,7 +211,7 @@ class TacotronConfig(BaseTTSConfig):
ga_alpha: float = 5.0
# testing
- test_sentences: List[str] = field(
+ test_sentences: list[str] | list[list[str]] = field(
default_factory=lambda: [
"It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.",
"Be a voice, not an echo.",
@@ -224,12 +223,12 @@ class TacotronConfig(BaseTTSConfig):
def check_values(self):
if self.gradual_training:
- assert (
- self.gradual_training[0][1] == self.r
- ), f"[!] the first scheduled gradual training `r` must be equal to the model's `r` value. {self.gradual_training[0][1]} vs {self.r}"
+ assert self.gradual_training[0][1] == self.r, (
+ f"[!] the first scheduled gradual training `r` must be equal to the model's `r` value. {self.gradual_training[0][1]} vs {self.r}"
+ )
if self.model == "tacotron" and self.audio is not None:
- assert self.out_channels == (
- self.audio.fft_size // 2 + 1
- ), f"{self.out_channels} vs {self.audio.fft_size // 2 + 1}"
+ assert self.out_channels == (self.audio.fft_size // 2 + 1), (
+ f"{self.out_channels} vs {self.audio.fft_size // 2 + 1}"
+ )
if self.model == "tacotron2" and self.audio is not None:
assert self.out_channels == self.audio.num_mels
diff --git a/TTS/tts/configs/vits_config.py b/TTS/tts/configs/vits_config.py
index 2d0242bf13..9ad720da30 100644
--- a/TTS/tts/configs/vits_config.py
+++ b/TTS/tts/configs/vits_config.py
@@ -1,5 +1,4 @@
from dataclasses import dataclass, field
-from typing import List
from TTS.tts.configs.shared_configs import BaseTTSConfig
from TTS.tts.models.vits import VitsArgs, VitsAudioConfig
@@ -112,7 +111,7 @@ class VitsConfig(BaseTTSConfig):
audio: VitsAudioConfig = field(default_factory=VitsAudioConfig)
# optimizer
- grad_clip: List[float] = field(default_factory=lambda: [1000, 1000])
+ grad_clip: list[float] = field(default_factory=lambda: [1000, 1000])
lr_gen: float = 0.0002
lr_disc: float = 0.0002
lr_scheduler_gen: str = "ExponentialLR"
@@ -146,7 +145,7 @@ class VitsConfig(BaseTTSConfig):
add_blank: bool = True
# testing
- test_sentences: List[List] = field(
+ test_sentences: list[str] | list[list[str]] = field(
default_factory=lambda: [
["It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent."],
["Be a voice, not an echo."],
@@ -167,7 +166,7 @@ class VitsConfig(BaseTTSConfig):
# use d-vectors
use_d_vector_file: bool = False
- d_vector_file: List[str] = None
+ d_vector_file: list[str] = None
d_vector_dim: int = None
def __post_init__(self):
diff --git a/TTS/tts/configs/xtts_config.py b/TTS/tts/configs/xtts_config.py
index bbf048e1ab..da6cc6edc6 100644
--- a/TTS/tts/configs/xtts_config.py
+++ b/TTS/tts/configs/xtts_config.py
@@ -1,5 +1,4 @@
from dataclasses import dataclass, field
-from typing import List
from TTS.tts.configs.shared_configs import BaseTTSConfig
from TTS.tts.models.xtts import XttsArgs, XttsAudioConfig
@@ -70,7 +69,7 @@ class XttsConfig(BaseTTSConfig):
model_args: XttsArgs = field(default_factory=XttsArgs)
audio: XttsAudioConfig = field(default_factory=XttsAudioConfig)
model_dir: str = None
- languages: List[str] = field(
+ languages: list[str] = field(
default_factory=lambda: [
"en",
"es",
diff --git a/TTS/tts/datasets/__init__.py b/TTS/tts/datasets/__init__.py
index d1a37da4c1..d83abce00a 100644
--- a/TTS/tts/datasets/__init__.py
+++ b/TTS/tts/datasets/__init__.py
@@ -2,8 +2,8 @@
import os
import sys
from collections import Counter
+from collections.abc import Callable
from pathlib import Path
-from typing import Callable, Dict, List, Tuple, Union
import numpy as np
@@ -17,7 +17,7 @@ def split_dataset(items, eval_split_max_size=None, eval_split_size=0.01):
"""Split a dataset into train and eval. Consider speaker distribution in multi-speaker training.
Args:
- items (List[List]):
+ items (list[list]):
A list of samples. Each sample is a list of `[audio_path, text, speaker_id]`.
eval_split_max_size (int):
@@ -37,10 +37,8 @@ def split_dataset(items, eval_split_max_size=None, eval_split_size=0.01):
else:
eval_split_size = int(len(items) * eval_split_size)
- assert (
- eval_split_size > 0
- ), " [!] You do not have enough samples for the evaluation set. You can work around this setting the 'eval_split_size' parameter to a minimum of {}".format(
- 1 / len(items)
+ assert eval_split_size > 0, (
+ f" [!] You do not have enough samples for the evaluation set. You can work around this setting the 'eval_split_size' parameter to a minimum of {1 / len(items)}"
)
np.random.seed(0)
np.random.shuffle(items)
@@ -71,18 +69,18 @@ def add_extra_keys(metadata, language, dataset_name):
def load_tts_samples(
- datasets: Union[List[Dict], Dict],
+ datasets: list[dict] | dict,
eval_split=True,
formatter: Callable = None,
eval_split_max_size=None,
eval_split_size=0.01,
-) -> Tuple[List[List], List[List]]:
- """Parse the dataset from the datasets config, load the samples as a List and load the attention alignments if provided.
+) -> tuple[list[list], list[list]]:
+ """Parse the dataset from the datasets config, load the samples as a list and load the attention alignments if provided.
If `formatter` is not None, apply the formatter to the samples else pick the formatter from the available ones based
on the dataset name.
Args:
- datasets (List[Dict], Dict): A list of datasets or a single dataset dictionary. If multiple datasets are
+ datasets (list[dict], dict): A list of datasets or a single dataset dictionary. If multiple datasets are
in the list, they are all merged.
eval_split (bool, optional): If true, create a evaluation split. If an eval split provided explicitly, generate
@@ -101,7 +99,7 @@ def load_tts_samples(
If > 1, represents the absolute number of evaluation samples. Defaults to 0.01 (1%).
Returns:
- Tuple[List[List], List[List]: training and evaluation splits of the dataset.
+ tuple[list[list], list[list]: training and evaluation splits of the dataset.
"""
meta_data_train_all = []
meta_data_eval_all = [] if eval_split else None
@@ -153,7 +151,7 @@ def load_tts_samples(
def load_attention_mask_meta_data(metafile_path):
"""Load meta data file created by compute_attention_masks.py"""
- with open(metafile_path, "r", encoding="utf-8") as f:
+ with open(metafile_path, encoding="utf-8") as f:
lines = f.readlines()
meta_data = []
diff --git a/TTS/tts/datasets/dataset.py b/TTS/tts/datasets/dataset.py
index 5f629f32a9..6f21dcd1e0 100644
--- a/TTS/tts/datasets/dataset.py
+++ b/TTS/tts/datasets/dataset.py
@@ -3,7 +3,7 @@
import logging
import os
import random
-from typing import Any, Optional, Union
+from typing import Any
import numpy as np
import numpy.typing as npt
@@ -47,7 +47,7 @@ def string2filename(string: str) -> str:
return base64.urlsafe_b64encode(string.encode("utf-8")).decode("utf-8", "ignore")
-def get_audio_size(audiopath: Union[str, os.PathLike[Any]]) -> int:
+def get_audio_size(audiopath: str | os.PathLike[Any]) -> int:
"""Return the number of samples in the audio file."""
if not isinstance(audiopath, str):
audiopath = str(audiopath)
@@ -63,7 +63,7 @@ def get_audio_size(audiopath: Union[str, os.PathLike[Any]]) -> int:
raise RuntimeError(msg) from e
-def get_attribute_balancer_weights(items: list, attr_name: str, multi_dict: Optional[dict] = None):
+def get_attribute_balancer_weights(items: list, attr_name: str, multi_dict: dict | None = None):
"""Create inverse frequency weights for balancing the dataset.
Use `multi_dict` to scale relative weights."""
@@ -94,23 +94,23 @@ def __init__(
outputs_per_step: int = 1,
compute_linear_spec: bool = False,
ap: AudioProcessor = None,
- samples: Optional[list[dict]] = None,
+ samples: list[dict] | None = None,
tokenizer: "TTSTokenizer" = None,
compute_f0: bool = False,
compute_energy: bool = False,
- f0_cache_path: Optional[str] = None,
- energy_cache_path: Optional[str] = None,
+ f0_cache_path: str | None = None,
+ energy_cache_path: str | None = None,
return_wav: bool = False,
batch_group_size: int = 0,
min_text_len: int = 0,
max_text_len: int = float("inf"),
min_audio_len: int = 0,
max_audio_len: int = float("inf"),
- phoneme_cache_path: Optional[str] = None,
+ phoneme_cache_path: str | None = None,
precompute_num_workers: int = 0,
- speaker_id_mapping: Optional[dict] = None,
- d_vector_mapping: Optional[dict] = None,
- language_id_mapping: Optional[dict] = None,
+ speaker_id_mapping: dict | None = None,
+ d_vector_mapping: dict | None = None,
+ language_id_mapping: dict | None = None,
use_noise_augment: bool = False,
start_by_longest: bool = False,
) -> None:
@@ -231,7 +231,7 @@ def lengths(self) -> list[int]:
try:
audio_len = get_audio_size(wav_file)
except RuntimeError:
- logger.warning(f"Failed to compute length for {item['audio_file']}")
+ logger.warning("Failed to compute length for %s", item["audio_file"])
audio_len = 0
lens.append(audio_len)
return lens
@@ -352,7 +352,7 @@ def _compute_lengths(samples):
try:
audio_length = get_audio_size(item["audio_file"])
except RuntimeError:
- logger.warning(f"Failed to compute length, skipping {item['audio_file']}")
+ logger.warning("Failed to compute length, skipping %s", item["audio_file"])
continue
text_lenght = len(item["text"])
item["audio_length"] = audio_length
@@ -437,14 +437,14 @@ def preprocess_samples(self) -> None:
self.samples = samples
logger.info("Preprocessing samples")
- logger.info(f"Max text length: {np.max(text_lengths)}")
- logger.info(f"Min text length: {np.min(text_lengths)}")
- logger.info(f"Avg text length: {np.mean(text_lengths)}")
- logger.info(f"Max audio length: {np.max(audio_lengths)}")
- logger.info(f"Min audio length: {np.min(audio_lengths)}")
- logger.info(f"Avg audio length: {np.mean(audio_lengths)}")
+ logger.info("Max text length: %d", np.max(text_lengths))
+ logger.info("Min text length: %d", np.min(text_lengths))
+ logger.info("Avg text length: %.2f", np.mean(text_lengths))
+ logger.info("Max audio length: %.2f", np.max(audio_lengths))
+ logger.info("Min audio length: %.2f", np.min(audio_lengths))
+ logger.info("Avg audio length: %.2f", np.mean(audio_lengths))
logger.info("Num. instances discarded samples: %d", len(ignore_idx))
- logger.info(f"Batch group size: {self.batch_group_size}.")
+ logger.info("Batch group size: %d", self.batch_group_size)
@staticmethod
def _sort_batch(batch, text_lengths):
@@ -640,7 +640,7 @@ class PhonemeDataset(Dataset):
def __init__(
self,
- samples: Union[list[dict], list[list]],
+ samples: list[dict] | list[list],
tokenizer: "TTSTokenizer",
cache_path: str,
precompute_num_workers: int = 0,
@@ -744,10 +744,10 @@ class F0Dataset:
def __init__(
self,
- samples: Union[list[list], list[dict]],
+ samples: list[list] | list[dict],
ap: "AudioProcessor",
audio_config=None, # pylint: disable=unused-argument
- cache_path: Optional[str] = None,
+ cache_path: str | None = None,
precompute_num_workers: int = 0,
normalize_f0: bool = True,
) -> None:
@@ -896,9 +896,9 @@ class EnergyDataset:
def __init__(
self,
- samples: Union[list[list], list[dict]],
+ samples: list[list] | list[dict],
ap: "AudioProcessor",
- cache_path: Optional[str] = None,
+ cache_path: str | None = None,
precompute_num_workers=0,
normalize_energy=True,
) -> None:
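Besides the typing cleanup, the `dataset.py` hunks above switch logging calls from f-strings to %-style arguments so that formatting is deferred until a record is actually emitted. A short illustration of the difference (the variable names are made up):

```python
# Eager vs. lazy log formatting, as illustrated by the changes above.
import logging

logger = logging.getLogger(__name__)
audio_file = "clips/sample_001.wav"

# Eager: the f-string is built even when WARNING records are filtered out.
logger.warning(f"Failed to compute length for {audio_file}")

# Lazy: interpolation only happens if the record is emitted; this is the
# style the diff converts to.
logger.warning("Failed to compute length for %s", audio_file)
```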
diff --git a/TTS/tts/datasets/formatters.py b/TTS/tts/datasets/formatters.py
index ff1a76e2c9..3a4605275a 100644
--- a/TTS/tts/datasets/formatters.py
+++ b/TTS/tts/datasets/formatters.py
@@ -5,7 +5,6 @@
import xml.etree.ElementTree as ET
from glob import glob
from pathlib import Path
-from typing import List
from tqdm import tqdm
@@ -21,7 +20,7 @@ def cml_tts(root_path, meta_file, ignored_speakers=None):
https://github.com/freds0/CML-TTS-Dataset/"""
filepath = os.path.join(root_path, meta_file)
# ensure there are 4 columns for every line
- with open(filepath, "r", encoding="utf8") as f:
+ with open(filepath, encoding="utf8") as f:
lines = f.readlines()
num_cols = len(lines[0].split("|")) # take the first row as reference
for idx, line in enumerate(lines[1:]):
@@ -61,7 +60,7 @@ def coqui(root_path, meta_file, ignored_speakers=None):
"""Interal dataset formatter."""
filepath = os.path.join(root_path, meta_file)
# ensure there are 4 columns for every line
- with open(filepath, "r", encoding="utf8") as f:
+ with open(filepath, encoding="utf8") as f:
lines = f.readlines()
num_cols = len(lines[0].split("|")) # take the first row as reference
for idx, line in enumerate(lines[1:]):
@@ -104,7 +103,7 @@ def tweb(root_path, meta_file, **kwargs): # pylint: disable=unused-argument
txt_file = os.path.join(root_path, meta_file)
items = []
speaker_name = "tweb"
- with open(txt_file, "r", encoding="utf-8") as ttf:
+ with open(txt_file, encoding="utf-8") as ttf:
for line in ttf:
cols = line.split("\t")
wav_file = os.path.join(root_path, cols[0] + ".wav")
@@ -118,7 +117,7 @@ def mozilla(root_path, meta_file, **kwargs): # pylint: disable=unused-argument
txt_file = os.path.join(root_path, meta_file)
items = []
speaker_name = "mozilla"
- with open(txt_file, "r", encoding="utf-8") as ttf:
+ with open(txt_file, encoding="utf-8") as ttf:
for line in ttf:
cols = line.split("|")
wav_file = cols[1].strip()
@@ -133,7 +132,7 @@ def mozilla_de(root_path, meta_file, **kwargs): # pylint: disable=unused-argume
txt_file = os.path.join(root_path, meta_file)
items = []
speaker_name = "mozilla"
- with open(txt_file, "r", encoding="ISO 8859-1") as ttf:
+ with open(txt_file, encoding="ISO 8859-1") as ttf:
for line in ttf:
cols = line.strip().split("|")
wav_file = cols[0].strip()
@@ -177,7 +176,7 @@ def mailabs(root_path, meta_files=None, ignored_speakers=None):
if speaker_name in ignored_speakers:
continue
logger.info(csv_file)
- with open(txt_file, "r", encoding="utf-8") as ttf:
+ with open(txt_file, encoding="utf-8") as ttf:
for line in ttf:
cols = line.split("|")
if not meta_files:
@@ -201,7 +200,7 @@ def ljspeech(root_path, meta_file, **kwargs): # pylint: disable=unused-argument
txt_file = os.path.join(root_path, meta_file)
items = []
speaker_name = "ljspeech"
- with open(txt_file, "r", encoding="utf-8") as ttf:
+ with open(txt_file, encoding="utf-8") as ttf:
for line in ttf:
cols = line.split("|")
wav_file = os.path.join(root_path, "wavs", cols[0] + ".wav")
@@ -215,7 +214,7 @@ def ljspeech_test(root_path, meta_file, **kwargs): # pylint: disable=unused-arg
https://keithito.com/LJ-Speech-Dataset/"""
txt_file = os.path.join(root_path, meta_file)
items = []
- with open(txt_file, "r", encoding="utf-8") as ttf:
+ with open(txt_file, encoding="utf-8") as ttf:
speaker_id = 0
for idx, line in enumerate(ttf):
# 2 samples per speaker to avoid eval split issues
@@ -236,7 +235,7 @@ def thorsten(root_path, meta_file, **kwargs): # pylint: disable=unused-argument
txt_file = os.path.join(root_path, meta_file)
items = []
speaker_name = "thorsten"
- with open(txt_file, "r", encoding="utf-8") as ttf:
+ with open(txt_file, encoding="utf-8") as ttf:
for line in ttf:
cols = line.split("|")
wav_file = os.path.join(root_path, "wavs", cols[0] + ".wav")
@@ -268,7 +267,7 @@ def ruslan(root_path, meta_file, **kwargs): # pylint: disable=unused-argument
txt_file = os.path.join(root_path, meta_file)
items = []
speaker_name = "ruslan"
- with open(txt_file, "r", encoding="utf-8") as ttf:
+ with open(txt_file, encoding="utf-8") as ttf:
for line in ttf:
cols = line.split("|")
wav_file = os.path.join(root_path, "RUSLAN", cols[0] + ".wav")
@@ -282,7 +281,7 @@ def css10(root_path, meta_file, **kwargs): # pylint: disable=unused-argument
txt_file = os.path.join(root_path, meta_file)
items = []
speaker_name = "css10"
- with open(txt_file, "r", encoding="utf-8") as ttf:
+ with open(txt_file, encoding="utf-8") as ttf:
for line in ttf:
cols = line.split("|")
wav_file = os.path.join(root_path, cols[0])
@@ -296,7 +295,7 @@ def nancy(root_path, meta_file, **kwargs): # pylint: disable=unused-argument
txt_file = os.path.join(root_path, meta_file)
items = []
speaker_name = "nancy"
- with open(txt_file, "r", encoding="utf-8") as ttf:
+ with open(txt_file, encoding="utf-8") as ttf:
for line in ttf:
utt_id = line.split()[1]
text = line[line.find('"') + 1 : line.rfind('"') - 1]
@@ -309,7 +308,7 @@ def common_voice(root_path, meta_file, ignored_speakers=None):
"""Normalize the common voice meta data file to TTS format."""
txt_file = os.path.join(root_path, meta_file)
items = []
- with open(txt_file, "r", encoding="utf-8") as ttf:
+ with open(txt_file, encoding="utf-8") as ttf:
for line in ttf:
if line.startswith("client_id"):
continue
@@ -338,7 +337,7 @@ def libri_tts(root_path, meta_files=None, ignored_speakers=None):
for meta_file in meta_files:
_meta_file = os.path.basename(meta_file).split(".")[0]
- with open(meta_file, "r", encoding="utf-8") as ttf:
+ with open(meta_file, encoding="utf-8") as ttf:
for line in ttf:
cols = line.split("\t")
file_name = cols[0]
@@ -368,7 +367,7 @@ def custom_turkish(root_path, meta_file, **kwargs): # pylint: disable=unused-ar
items = []
speaker_name = "turkish-female"
skipped_files = []
- with open(txt_file, "r", encoding="utf-8") as ttf:
+ with open(txt_file, encoding="utf-8") as ttf:
for line in ttf:
cols = line.split("|")
wav_file = os.path.join(root_path, "wavs", cols[0].strip() + ".wav")
@@ -386,7 +385,7 @@ def brspeech(root_path, meta_file, ignored_speakers=None):
"""BRSpeech 3.0 beta"""
txt_file = os.path.join(root_path, meta_file)
items = []
- with open(txt_file, "r", encoding="utf-8") as ttf:
+ with open(txt_file, encoding="utf-8") as ttf:
for line in ttf:
if line.startswith("wav_filename"):
continue
@@ -425,7 +424,7 @@ def vctk(root_path, meta_files=None, wavs_path="wav48_silence_trimmed", mic="mic
"""
file_ext = "flac"
items = []
- meta_files = glob(f"{os.path.join(root_path,'txt')}/**/*.txt", recursive=True)
+ meta_files = glob(f"{os.path.join(root_path, 'txt')}/**/*.txt", recursive=True)
for meta_file in meta_files:
_, speaker_id, txt_file = os.path.relpath(meta_file, root_path).split(os.sep)
file_id = txt_file.split(".")[0]
@@ -433,7 +432,7 @@ def vctk(root_path, meta_files=None, wavs_path="wav48_silence_trimmed", mic="mic
if isinstance(ignored_speakers, list):
if speaker_id in ignored_speakers:
continue
- with open(meta_file, "r", encoding="utf-8") as file_text:
+ with open(meta_file, encoding="utf-8") as file_text:
text = file_text.readlines()[0]
# p280 has no mic2 recordings
if speaker_id == "p280":
@@ -452,7 +451,7 @@ def vctk(root_path, meta_files=None, wavs_path="wav48_silence_trimmed", mic="mic
def vctk_old(root_path, meta_files=None, wavs_path="wav48", ignored_speakers=None):
"""homepages.inf.ed.ac.uk/jyamagis/release/VCTK-Corpus.tar.gz"""
items = []
- meta_files = glob(f"{os.path.join(root_path,'txt')}/**/*.txt", recursive=True)
+ meta_files = glob(f"{os.path.join(root_path, 'txt')}/**/*.txt", recursive=True)
for meta_file in meta_files:
_, speaker_id, txt_file = os.path.relpath(meta_file, root_path).split(os.sep)
file_id = txt_file.split(".")[0]
@@ -460,7 +459,7 @@ def vctk_old(root_path, meta_files=None, wavs_path="wav48", ignored_speakers=Non
if isinstance(ignored_speakers, list):
if speaker_id in ignored_speakers:
continue
- with open(meta_file, "r", encoding="utf-8") as file_text:
+ with open(meta_file, encoding="utf-8") as file_text:
text = file_text.readlines()[0]
wav_file = os.path.join(root_path, wavs_path, speaker_id, file_id + ".wav")
items.append(
@@ -482,7 +481,7 @@ def synpaflex(root_path, metafiles=None, **kwargs): # pylint: disable=unused-ar
os.path.dirname(wav_file), "txt", os.path.basename(wav_file).replace(".wav", ".txt")
)
if os.path.exists(txt_file) and os.path.exists(wav_file):
- with open(txt_file, "r", encoding="utf-8") as file_text:
+ with open(txt_file, encoding="utf-8") as file_text:
text = file_text.readlines()[0]
items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name, "root_path": root_path})
return items
@@ -500,7 +499,7 @@ def open_bible(root_path, meta_files="train", ignore_digits_sentences=True, igno
if isinstance(ignored_speakers, list):
if speaker_id in ignored_speakers:
continue
- with open(meta_file, "r", encoding="utf-8") as file_text:
+ with open(meta_file, encoding="utf-8") as file_text:
text = file_text.readline().replace("\n", "")
# ignore sentences that contains digits
if ignore_digits_sentences and any(map(str.isdigit, text)):
@@ -513,7 +512,7 @@ def open_bible(root_path, meta_files="train", ignore_digits_sentences=True, igno
def mls(root_path, meta_files=None, ignored_speakers=None):
"""http://www.openslr.org/94/"""
items = []
- with open(os.path.join(root_path, meta_files), "r", encoding="utf-8") as meta:
+ with open(os.path.join(root_path, meta_files), encoding="utf-8") as meta:
for line in meta:
file, text = line.split("\t")
text = text[:-1]
@@ -553,7 +552,7 @@ def _voxcel_x(root_path, meta_file, voxcel_idx):
# if not exists meta file, crawl recursively for 'wav' files
if meta_file is not None:
- with open(str(meta_file), "r", encoding="utf-8") as f:
+ with open(str(meta_file), encoding="utf-8") as f:
return [x.strip().split("|") for x in f.readlines()]
elif not cache_to.exists():
@@ -575,7 +574,7 @@ def _voxcel_x(root_path, meta_file, voxcel_idx):
if cnt < expected_count:
raise ValueError(f"Found too few instances for Voxceleb. Should be around {expected_count}, is: {cnt}")
- with open(str(cache_to), "r", encoding="utf-8") as f:
+ with open(str(cache_to), encoding="utf-8") as f:
return [x.strip().split("|") for x in f.readlines()]
@@ -583,7 +582,7 @@ def emotion(root_path, meta_file, ignored_speakers=None):
"""Generic emotion dataset"""
txt_file = os.path.join(root_path, meta_file)
items = []
- with open(txt_file, "r", encoding="utf-8") as ttf:
+ with open(txt_file, encoding="utf-8") as ttf:
for line in ttf:
if line.startswith("file_path"):
continue
@@ -601,7 +600,7 @@ def emotion(root_path, meta_file, ignored_speakers=None):
return items
-def baker(root_path: str, meta_file: str, **kwargs) -> List[List[str]]: # pylint: disable=unused-argument
+def baker(root_path: str, meta_file: str, **kwargs) -> list[list[str]]: # pylint: disable=unused-argument
"""Normalizes the Baker meta data file to TTS format
Args:
@@ -613,7 +612,7 @@ def baker(root_path: str, meta_file: str, **kwargs) -> List[List[str]]: # pylin
txt_file = os.path.join(root_path, meta_file)
items = []
speaker_name = "baker"
- with open(txt_file, "r", encoding="utf-8") as ttf:
+ with open(txt_file, encoding="utf-8") as ttf:
for line in ttf:
wav_name, text = line.rstrip("\n").split("|")
wav_path = os.path.join(root_path, "clips_22", wav_name)
@@ -626,7 +625,7 @@ def kokoro(root_path, meta_file, **kwargs): # pylint: disable=unused-argument
txt_file = os.path.join(root_path, meta_file)
items = []
speaker_name = "kokoro"
- with open(txt_file, "r", encoding="utf-8") as ttf:
+ with open(txt_file, encoding="utf-8") as ttf:
for line in ttf:
cols = line.split("|")
wav_file = os.path.join(root_path, "wavs", cols[0] + ".wav")
@@ -640,7 +639,7 @@ def kss(root_path, meta_file, **kwargs): # pylint: disable=unused-argument
txt_file = os.path.join(root_path, meta_file)
items = []
speaker_name = "kss"
- with open(txt_file, "r", encoding="utf-8") as ttf:
+ with open(txt_file, encoding="utf-8") as ttf:
for line in ttf:
cols = line.split("|")
wav_file = os.path.join(root_path, cols[0])
@@ -653,7 +652,7 @@ def bel_tts_formatter(root_path, meta_file, **kwargs): # pylint: disable=unused
txt_file = os.path.join(root_path, meta_file)
items = []
speaker_name = "bel_tts"
- with open(txt_file, "r", encoding="utf-8") as ttf:
+ with open(txt_file, encoding="utf-8") as ttf:
for line in ttf:
cols = line.split("|")
wav_file = os.path.join(root_path, cols[0])
diff --git a/TTS/tts/layers/bark/hubert/kmeans_hubert.py b/TTS/tts/layers/bark/hubert/kmeans_hubert.py
index ade84794eb..87be97d5d1 100644
--- a/TTS/tts/layers/bark/hubert/kmeans_hubert.py
+++ b/TTS/tts/layers/bark/hubert/kmeans_hubert.py
@@ -7,7 +7,6 @@
# Modified code from https://github.com/lucidrains/audiolm-pytorch/blob/main/audiolm_pytorch/hubert_kmeans.py
-
import torch
from einops import pack, unpack
from torch import nn
diff --git a/TTS/tts/layers/bark/inference_funcs.py b/TTS/tts/layers/bark/inference_funcs.py
index 65c7800dcf..457a20ea28 100644
--- a/TTS/tts/layers/bark/inference_funcs.py
+++ b/TTS/tts/layers/bark/inference_funcs.py
@@ -2,7 +2,6 @@
import os
import re
from glob import glob
-from typing import Dict, List, Optional, Tuple
import librosa
import numpy as np
@@ -34,9 +33,9 @@ def _normalize_whitespace(text):
return re.sub(r"\s+", " ", text).strip()
-def get_voices(extra_voice_dirs: List[str] = []): # pylint: disable=dangerous-default-value
+def get_voices(extra_voice_dirs: list[str] = []): # pylint: disable=dangerous-default-value
dirs = extra_voice_dirs
- voices: Dict[str, List[str]] = {}
+ voices: dict[str, list[str]] = {}
for d in dirs:
subs = os.listdir(d)
for sub in subs:
@@ -49,7 +48,7 @@ def get_voices(extra_voice_dirs: List[str] = []): # pylint: disable=dangerous-d
return voices
-def load_npz(npz_file: str) -> Tuple[npt.NDArray[np.int64], npt.NDArray[np.int64], npt.NDArray[np.int64]]:
+def load_npz(npz_file: str) -> tuple[npt.NDArray[np.int64], npt.NDArray[np.int64], npt.NDArray[np.int64]]:
x_history = np.load(npz_file)
semantic = x_history["semantic_prompt"]
coarse = x_history["coarse_prompt"]
@@ -58,10 +57,8 @@ def load_npz(npz_file: str) -> Tuple[npt.NDArray[np.int64], npt.NDArray[np.int64
def load_voice(
- model, voice: str, extra_voice_dirs: List[str] = []
-) -> Tuple[
- Optional[npt.NDArray[np.int64]], Optional[npt.NDArray[np.int64]], Optional[npt.NDArray[np.int64]]
-]: # pylint: disable=dangerous-default-value
+ model, voice: str, extra_voice_dirs: list[str] = []
+) -> tuple[npt.NDArray[np.int64] | None, npt.NDArray[np.int64] | None, npt.NDArray[np.int64] | None]: # pylint: disable=dangerous-default-value
if voice == "random":
return None, None, None
@@ -206,8 +203,8 @@ def generate_text_semantic(
semantic_history = None
encoded_text = np.array(_tokenize(model.tokenizer, text)) + model.config.TEXT_ENCODING_OFFSET
if len(encoded_text) > 256:
- p = round((len(encoded_text) - 256) / len(encoded_text) * 100, 1)
- logger.warning(f"warning, text too long, lopping of last {p}%")
+ p = (len(encoded_text) - 256) / len(encoded_text) * 100
+ logger.warning("warning, text too long, lopping of last %.1f%%", p)
encoded_text = encoded_text[:256]
encoded_text = np.pad(
encoded_text,
diff --git a/TTS/tts/layers/bark/load_model.py b/TTS/tts/layers/bark/load_model.py
index c1e0d006cb..dcec5b5bbc 100644
--- a/TTS/tts/layers/bark/load_model.py
+++ b/TTS/tts/layers/bark/load_model.py
@@ -88,7 +88,7 @@ def clear_cuda_cache():
def load_model(ckpt_path, device, config, model_type="text"):
- logger.info(f"loading {model_type} model from {ckpt_path}...")
+ logger.info("loading %s model from %s...", model_type, ckpt_path)
if device == "cpu":
logger.warning("No GPU being used. Careful, Inference might be extremely slow!")
@@ -108,10 +108,10 @@ def load_model(ckpt_path, device, config, model_type="text"):
and os.path.exists(ckpt_path)
and _md5(ckpt_path) != config.REMOTE_MODEL_PATHS[model_type]["checksum"]
):
- logger.warning(f"found outdated {model_type} model, removing...")
+ logger.warning("found outdated %s model, removing...", model_type)
os.remove(ckpt_path)
if not os.path.exists(ckpt_path):
- logger.info(f"{model_type} model not found, downloading...")
+ logger.info("%s model not found, downloading...", model_type)
# The URL in the config is a 404 and needs to be fixed
download_url = config.REMOTE_MODEL_PATHS[model_type]["path"].replace("tree", "resolve")
_download(download_url, ckpt_path, config.CACHE_DIR)
@@ -150,7 +150,7 @@ def load_model(ckpt_path, device, config, model_type="text"):
model.load_state_dict(state_dict, strict=False)
n_params = model.get_num_params()
val_loss = checkpoint["best_val_loss"].item()
- logger.info(f"model loaded: {round(n_params/1e6,1)}M params, {round(val_loss,3)} loss")
+ logger.info("model loaded: %.1fM params, %.3f loss", n_params / 1e6, val_loss)
model.eval()
model.to(device)
del checkpoint, state_dict
diff --git a/TTS/tts/layers/bark/model.py b/TTS/tts/layers/bark/model.py
index 54a9cecec0..4850d0a88b 100644
--- a/TTS/tts/layers/bark/model.py
+++ b/TTS/tts/layers/bark/model.py
@@ -175,9 +175,9 @@ def forward(self, idx, merge_context=False, past_kv=None, position_ids=None, use
assert idx.shape[1] >= 256 + 256 + 1
t = idx.shape[1] - 256
else:
- assert (
- t <= self.config.block_size
- ), f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}"
+ assert t <= self.config.block_size, (
+ f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}"
+ )
# forward the GPT model itself
if merge_context:
diff --git a/TTS/tts/layers/bark/model_fine.py b/TTS/tts/layers/bark/model_fine.py
index 29126b41ab..20f54d2152 100644
--- a/TTS/tts/layers/bark/model_fine.py
+++ b/TTS/tts/layers/bark/model_fine.py
@@ -101,9 +101,9 @@ def __init__(self, config):
def forward(self, pred_idx, idx):
device = idx.device
b, t, codes = idx.size()
- assert (
- t <= self.config.block_size
- ), f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}"
+ assert t <= self.config.block_size, (
+ f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}"
+ )
assert pred_idx > 0, "cannot predict 0th codebook"
assert codes == self.n_codes_total, (b, t, codes)
pos = torch.arange(0, t, dtype=torch.long, device=device).unsqueeze(0) # shape (1, t)
diff --git a/TTS/tts/layers/delightful_tts/acoustic_model.py b/TTS/tts/layers/delightful_tts/acoustic_model.py
index 2aa82c9a88..9110ff5fd0 100644
--- a/TTS/tts/layers/delightful_tts/acoustic_model.py
+++ b/TTS/tts/layers/delightful_tts/acoustic_model.py
@@ -1,6 +1,6 @@
### credit: https://github.com/dunky11/voicesmith
import logging
-from typing import Callable, Dict, Tuple
+from collections.abc import Callable
import torch
import torch.nn.functional as F
@@ -177,7 +177,7 @@ def init_multispeaker(self, args: Coqpit): # pylint: disable=unused-argument
self._init_d_vector()
@staticmethod
- def _set_cond_input(aux_input: Dict):
+ def _set_cond_input(aux_input: dict):
"""Set the speaker conditioning input based on the multi-speaker mode."""
sid, g, lid, durations = None, None, None, None
if "speaker_ids" in aux_input and aux_input["speaker_ids"] is not None:
@@ -194,11 +194,11 @@ def _set_cond_input(aux_input: Dict):
return sid, g, lid, durations
- def get_aux_input(self, aux_input: Dict):
+ def get_aux_input(self, aux_input: dict):
sid, g, lid, _ = self._set_cond_input(aux_input)
return {"speaker_ids": sid, "style_wav": None, "d_vectors": g, "language_ids": lid}
- def _set_speaker_input(self, aux_input: Dict):
+ def _set_speaker_input(self, aux_input: dict):
d_vectors = aux_input.get("d_vectors", None)
speaker_ids = aux_input.get("speaker_ids", None)
@@ -237,7 +237,7 @@ def _forward_aligner(
x_mask: torch.IntTensor,
y_mask: torch.IntTensor,
attn_priors: torch.FloatTensor,
- ) -> Tuple[torch.IntTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
+ ) -> tuple[torch.IntTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
"""Aligner forward pass.
1. Compute a mask to apply to the attention map.
@@ -298,7 +298,7 @@ def forward(
use_ground_truth: bool = True,
d_vectors: torch.Tensor = None,
speaker_idx: torch.Tensor = None,
- ) -> Dict[str, torch.Tensor]:
+ ) -> dict[str, torch.Tensor]:
sid, g, lid, _ = self._set_cond_input( # pylint: disable=unused-variable
{"d_vectors": d_vectors, "speaker_ids": speaker_idx}
) # pylint: disable=unused-variable
diff --git a/TTS/tts/layers/delightful_tts/conv_layers.py b/TTS/tts/layers/delightful_tts/conv_layers.py
index 1d5139571e..5cf41d4ff6 100644
--- a/TTS/tts/layers/delightful_tts/conv_layers.py
+++ b/TTS/tts/layers/delightful_tts/conv_layers.py
@@ -1,11 +1,9 @@
-from typing import Tuple
-
import torch
import torch.nn as nn # pylint: disable=consider-using-from-import
import torch.nn.functional as F
-def calc_same_padding(kernel_size: int) -> Tuple[int, int]:
+def calc_same_padding(kernel_size: int) -> tuple[int, int]:
pad = kernel_size // 2
return (pad, pad - (kernel_size + 1) % 2)
@@ -52,7 +50,7 @@ def __init__(
w_init_gain="linear",
use_weight_norm=False,
):
- super(ConvNorm, self).__init__() # pylint: disable=super-with-arguments
+ super().__init__()
if padding is None:
assert kernel_size % 2 == 1
padding = int(dilation * (kernel_size - 1) / 2)
@@ -94,7 +92,7 @@ def __init__(
lstm_type="bilstm",
use_linear=True,
):
- super(ConvLSTMLinear, self).__init__() # pylint: disable=super-with-arguments
+ super().__init__()
self.out_dim = out_dim
self.lstm_type = lstm_type
self.use_linear = use_linear
diff --git a/TTS/tts/layers/delightful_tts/encoders.py b/TTS/tts/layers/delightful_tts/encoders.py
index bd0c319dc1..31bab8cc97 100644
--- a/TTS/tts/layers/delightful_tts/encoders.py
+++ b/TTS/tts/layers/delightful_tts/encoders.py
@@ -1,5 +1,3 @@
-from typing import List, Tuple, Union
-
import torch
import torch.nn as nn # pylint: disable=consider-using-from-import
import torch.nn.functional as F
@@ -36,9 +34,9 @@ class ReferenceEncoder(nn.Module):
def __init__(
self,
num_mels: int,
- ref_enc_filters: List[Union[int, int, int, int, int, int]],
+ ref_enc_filters: list[int | int | int | int | int | int],
ref_enc_size: int,
- ref_enc_strides: List[Union[int, int, int, int, int]],
+ ref_enc_strides: list[int | int | int | int | int],
ref_enc_gru_size: int,
):
super().__init__()
@@ -80,7 +78,7 @@ def __init__(
batch_first=True,
)
- def forward(self, x: torch.Tensor, mel_lens: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ def forward(self, x: torch.Tensor, mel_lens: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
inputs --- [N, n_mels, timesteps]
outputs --- [N, E//2]
@@ -120,9 +118,9 @@ class UtteranceLevelProsodyEncoder(nn.Module):
def __init__(
self,
num_mels: int,
- ref_enc_filters: List[Union[int, int, int, int, int, int]],
+ ref_enc_filters: list[int | int | int | int | int | int],
ref_enc_size: int,
- ref_enc_strides: List[Union[int, int, int, int, int]],
+ ref_enc_strides: list[int | int | int | int | int],
ref_enc_gru_size: int,
dropout: float,
n_hidden: int,
@@ -192,9 +190,9 @@ class PhonemeLevelProsodyEncoder(nn.Module):
def __init__(
self,
num_mels: int,
- ref_enc_filters: List[Union[int, int, int, int, int, int]],
+ ref_enc_filters: list[int | int | int | int | int | int],
ref_enc_size: int,
- ref_enc_strides: List[Union[int, int, int, int, int]],
+ ref_enc_strides: list[int | int | int | int | int],
ref_enc_gru_size: int,
dropout: float,
n_hidden: int,
diff --git a/TTS/tts/layers/delightful_tts/energy_adaptor.py b/TTS/tts/layers/delightful_tts/energy_adaptor.py
index ea0d1e4721..d2b4b0ffa8 100644
--- a/TTS/tts/layers/delightful_tts/energy_adaptor.py
+++ b/TTS/tts/layers/delightful_tts/energy_adaptor.py
@@ -1,4 +1,4 @@
-from typing import Callable, Tuple
+from collections.abc import Callable
import torch
import torch.nn as nn # pylint: disable=consider-using-from-import
@@ -59,7 +59,7 @@ def __init__(
def get_energy_embedding_train(
self, x: torch.Tensor, target: torch.Tensor, dr: torch.IntTensor, mask: torch.Tensor
- ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Shapes:
x: :math: `[B, T_src, C]`
diff --git a/TTS/tts/layers/delightful_tts/networks.py b/TTS/tts/layers/delightful_tts/networks.py
index 4305022f18..93b65a2a74 100644
--- a/TTS/tts/layers/delightful_tts/networks.py
+++ b/TTS/tts/layers/delightful_tts/networks.py
@@ -1,5 +1,4 @@
import math
-from typing import Tuple
import numpy as np
import torch
@@ -9,7 +8,7 @@
from TTS.tts.layers.delightful_tts.conv_layers import ConvNorm
-def initialize_embeddings(shape: Tuple[int]) -> torch.Tensor:
+def initialize_embeddings(shape: tuple[int]) -> torch.Tensor:
assert len(shape) == 2, "Can only initialize 2-D embedding matrices ..."
# Kaiming initialization
return torch.randn(shape) * np.sqrt(2 / shape[1])
@@ -52,7 +51,7 @@ def __init__(
kernel_size=3,
use_partial_padding=False, # pylint: disable=unused-argument
):
- super(BottleneckLayer, self).__init__() # pylint: disable=super-with-arguments
+ super().__init__()
self.reduction_factor = reduction_factor
reduced_dim = int(in_dim / reduction_factor)
@@ -195,7 +194,7 @@ class STL(nn.Module):
"""
def __init__(self, n_hidden: int, token_num: int):
- super(STL, self).__init__() # pylint: disable=super-with-arguments
+ super().__init__()
num_heads = 1
E = n_hidden
diff --git a/TTS/tts/layers/delightful_tts/pitch_adaptor.py b/TTS/tts/layers/delightful_tts/pitch_adaptor.py
index 9031369e0f..14e751d2e2 100644
--- a/TTS/tts/layers/delightful_tts/pitch_adaptor.py
+++ b/TTS/tts/layers/delightful_tts/pitch_adaptor.py
@@ -1,4 +1,4 @@
-from typing import Callable, Tuple
+from collections.abc import Callable
import torch
import torch.nn as nn # pylint: disable=consider-using-from-import
@@ -58,7 +58,7 @@ def __init__(
def get_pitch_embedding_train(
self, x: torch.Tensor, target: torch.Tensor, dr: torch.IntTensor, mask: torch.Tensor
- ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Shapes:
x: :math: `[B, T_src, C]`
diff --git a/TTS/tts/layers/feed_forward/encoder.py b/TTS/tts/layers/feed_forward/encoder.py
index caf939ffc7..2d08f03c2d 100644
--- a/TTS/tts/layers/feed_forward/encoder.py
+++ b/TTS/tts/layers/feed_forward/encoder.py
@@ -143,9 +143,9 @@ def __init__(
elif encoder_type.lower() == "residual_conv_bn":
self.encoder = ResidualConv1dBNEncoder(in_hidden_channels, out_channels, in_hidden_channels, encoder_params)
elif encoder_type.lower() == "fftransformer":
- assert (
- in_hidden_channels == out_channels
- ), "[!] must be `in_channels` == `out_channels` when encoder type is 'fftransformer'"
+ assert in_hidden_channels == out_channels, (
+ "[!] must be `in_channels` == `out_channels` when encoder type is 'fftransformer'"
+ )
# pylint: disable=unexpected-keyword-arg
self.encoder = FFTransformerBlock(in_hidden_channels, **encoder_params)
else:
diff --git a/TTS/tts/layers/generic/aligner.py b/TTS/tts/layers/generic/aligner.py
index baa6f0e9c4..480c48f9a4 100644
--- a/TTS/tts/layers/generic/aligner.py
+++ b/TTS/tts/layers/generic/aligner.py
@@ -1,5 +1,3 @@
-from typing import Tuple
-
import torch
from torch import nn
@@ -68,7 +66,7 @@ def init_layers(self):
def forward(
self, queries: torch.tensor, keys: torch.tensor, mask: torch.tensor = None, attn_prior: torch.tensor = None
- ) -> Tuple[torch.tensor, torch.tensor]:
+ ) -> tuple[torch.tensor, torch.tensor]:
"""Forward pass of the aligner encoder.
Shapes:
- queries: :math:`[B, C, T_de]`
diff --git a/TTS/tts/layers/generic/pos_encoding.py b/TTS/tts/layers/generic/pos_encoding.py
index 913add0d14..7765e224aa 100644
--- a/TTS/tts/layers/generic/pos_encoding.py
+++ b/TTS/tts/layers/generic/pos_encoding.py
@@ -18,9 +18,7 @@ class PositionalEncoding(nn.Module):
def __init__(self, channels, dropout_p=0.0, max_len=5000, use_scale=False):
super().__init__()
if channels % 2 != 0:
- raise ValueError(
- "Cannot use sin/cos positional encoding with " "odd channels (got channels={:d})".format(channels)
- )
+ raise ValueError(f"Cannot use sin/cos positional encoding with odd channels (got channels={channels:d})")
self.use_scale = use_scale
if use_scale:
self.scale = torch.nn.Parameter(torch.ones(1))
diff --git a/TTS/tts/layers/generic/transformer.py b/TTS/tts/layers/generic/transformer.py
index 9b7ecee2ba..2fe9bcc408 100644
--- a/TTS/tts/layers/generic/transformer.py
+++ b/TTS/tts/layers/generic/transformer.py
@@ -70,9 +70,7 @@ def forward(self, x, mask=None, g=None): # pylint: disable=unused-argument
class FFTDurationPredictor:
- def __init__(
- self, in_channels, hidden_channels, num_heads, num_layers, dropout_p=0.1, cond_channels=None
- ): # pylint: disable=unused-argument
+ def __init__(self, in_channels, hidden_channels, num_heads, num_layers, dropout_p=0.1, cond_channels=None): # pylint: disable=unused-argument
self.fft = FFTransformerBlock(in_channels, num_heads, hidden_channels, num_layers, dropout_p)
self.proj = nn.Linear(in_channels, 1)
diff --git a/TTS/tts/layers/losses.py b/TTS/tts/layers/losses.py
index db62430c9d..1e744d62cf 100644
--- a/TTS/tts/layers/losses.py
+++ b/TTS/tts/layers/losses.py
@@ -814,7 +814,7 @@ def __init__(self, c):
elif c.spec_loss_type == "l1":
self.spec_loss = L1LossMasked(False)
else:
- raise ValueError(" [!] Unknown spec_loss_type {}".format(c.spec_loss_type))
+ raise ValueError(f" [!] Unknown spec_loss_type {c.spec_loss_type}")
if c.duration_loss_type == "mse":
self.dur_loss = MSELossMasked(False)
@@ -823,7 +823,7 @@ def __init__(self, c):
elif c.duration_loss_type == "huber":
self.dur_loss = Huber()
else:
- raise ValueError(" [!] Unknown duration_loss_type {}".format(c.duration_loss_type))
+ raise ValueError(f" [!] Unknown duration_loss_type {c.duration_loss_type}")
if c.model_args.use_aligner:
self.aligner_loss = ForwardSumLoss()
diff --git a/TTS/tts/layers/overflow/common_layers.py b/TTS/tts/layers/overflow/common_layers.py
index 9f77af293c..a477b34f0b 100644
--- a/TTS/tts/layers/overflow/common_layers.py
+++ b/TTS/tts/layers/overflow/common_layers.py
@@ -1,5 +1,4 @@
import logging
-from typing import List, Tuple
import torch
import torch.nn.functional as F
@@ -44,7 +43,7 @@ def __init__(self, num_chars, state_per_phone, in_out_channels=512, n_convolutio
)
self.rnn_state = None
- def forward(self, x: torch.FloatTensor, x_len: torch.LongTensor) -> Tuple[torch.FloatTensor, torch.LongTensor]:
+ def forward(self, x: torch.FloatTensor, x_len: torch.LongTensor) -> tuple[torch.FloatTensor, torch.LongTensor]:
"""Forward pass to the encoder.
Args:
@@ -110,7 +109,7 @@ class ParameterModel(nn.Module):
def __init__(
self,
- outputnet_size: List[int],
+ outputnet_size: list[int],
input_size: int,
output_size: int,
frame_channels: int,
@@ -152,7 +151,7 @@ def __init__(
encoder_dim: int,
memory_rnn_dim: int,
frame_channels: int,
- outputnet_size: List[int],
+ outputnet_size: list[int],
flat_start_params: dict,
std_floor: float = 1e-2,
):
diff --git a/TTS/tts/layers/overflow/neural_hmm.py b/TTS/tts/layers/overflow/neural_hmm.py
index a12becef03..9142f65e8c 100644
--- a/TTS/tts/layers/overflow/neural_hmm.py
+++ b/TTS/tts/layers/overflow/neural_hmm.py
@@ -1,5 +1,3 @@
-from typing import List
-
import torch
import torch.distributions as tdist
import torch.nn.functional as F
@@ -57,7 +55,7 @@ def __init__(
prenet_dropout: float,
prenet_dropout_at_inference: bool,
memory_rnn_dim: int,
- outputnet_size: List[int],
+ outputnet_size: list[int],
flat_start_params: dict,
std_floor: float,
use_grad_checkpointing: bool = True,
diff --git a/TTS/tts/layers/tacotron/tacotron.py b/TTS/tts/layers/tacotron/tacotron.py
index 32643dfcee..6f33edf3d7 100644
--- a/TTS/tts/layers/tacotron/tacotron.py
+++ b/TTS/tts/layers/tacotron/tacotron.py
@@ -1,4 +1,3 @@
-# coding: utf-8
# adapted from https://github.com/r9y9/tacotron_pytorch
import logging
diff --git a/TTS/tts/layers/tortoise/arch_utils.py b/TTS/tts/layers/tortoise/arch_utils.py
index 1bbf676393..508699fee3 100644
--- a/TTS/tts/layers/tortoise/arch_utils.py
+++ b/TTS/tts/layers/tortoise/arch_utils.py
@@ -6,7 +6,7 @@
import torch.nn as nn
import torch.nn.functional as F
import torchaudio
-from transformers import LogitsWarper
+from transformers import LogitsProcessor
from TTS.tts.layers.tortoise.xtransformers import ContinuousTransformerWrapper, RelativePositionBias
from TTS.utils.generic_utils import is_pytorch_at_least_2_4
@@ -101,9 +101,9 @@ def __init__(
if num_head_channels == -1:
self.num_heads = num_heads
else:
- assert (
- channels % num_head_channels == 0
- ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
+ assert channels % num_head_channels == 0, (
+ f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
+ )
self.num_heads = channels // num_head_channels
self.norm = normalization(channels)
self.qkv = nn.Conv1d(channels, channels * 3, 1)
@@ -292,7 +292,7 @@ def forward(self, x, **kwargs):
return h
-class TypicalLogitsWarper(LogitsWarper):
+class TypicalLogitsWarper(LogitsProcessor):
def __init__(
self,
mass: float = 0.9,
diff --git a/TTS/tts/layers/tortoise/audio_utils.py b/TTS/tts/layers/tortoise/audio_utils.py
index c67ee6c44b..6bbe6c389c 100644
--- a/TTS/tts/layers/tortoise/audio_utils.py
+++ b/TTS/tts/layers/tortoise/audio_utils.py
@@ -1,7 +1,6 @@
import logging
import os
from glob import glob
-from typing import Dict, List
import librosa
import numpy as np
@@ -88,9 +87,9 @@ def normalize_tacotron_mel(mel):
return 2 * ((mel - TACOTRON_MEL_MIN) / (TACOTRON_MEL_MAX - TACOTRON_MEL_MIN)) - 1
-def get_voices(extra_voice_dirs: List[str] = []):
+def get_voices(extra_voice_dirs: list[str] = []):
dirs = extra_voice_dirs
- voices: Dict[str, List[str]] = {}
+ voices: dict[str, list[str]] = {}
for d in dirs:
subs = os.listdir(d)
for sub in subs:
@@ -100,7 +99,7 @@ def get_voices(extra_voice_dirs: List[str] = []):
return voices
-def load_voice(voice: str, extra_voice_dirs: List[str] = []):
+def load_voice(voice: str, extra_voice_dirs: list[str] = []):
if voice == "random":
return None, None
@@ -116,7 +115,7 @@ def load_voice(voice: str, extra_voice_dirs: List[str] = []):
return conds, None
-def load_voices(voices: List[str], extra_voice_dirs: List[str] = []):
+def load_voices(voices: list[str], extra_voice_dirs: list[str] = []):
latents = []
clips = []
for voice in voices:
@@ -126,14 +125,14 @@ def load_voices(voices: List[str], extra_voice_dirs: List[str] = []):
return None, None
clip, latent = load_voice(voice, extra_voice_dirs)
if latent is None:
- assert (
- len(latents) == 0
- ), "Can only combine raw audio voices or latent voices, not both. Do it yourself if you want this."
+ assert len(latents) == 0, (
+ "Can only combine raw audio voices or latent voices, not both. Do it yourself if you want this."
+ )
clips.extend(clip)
elif clip is None:
- assert (
- len(clips) == 0
- ), "Can only combine raw audio voices or latent voices, not both. Do it yourself if you want this."
+ assert len(clips) == 0, (
+ "Can only combine raw audio voices or latent voices, not both. Do it yourself if you want this."
+ )
latents.append(latent)
if len(latents) == 0:
return clips, None
diff --git a/TTS/tts/layers/tortoise/autoregressive.py b/TTS/tts/layers/tortoise/autoregressive.py
index 00c884e973..eaeb2a03c1 100644
--- a/TTS/tts/layers/tortoise/autoregressive.py
+++ b/TTS/tts/layers/tortoise/autoregressive.py
@@ -1,7 +1,6 @@
# AGPL: a notification must be added stating that changes have been made to that file.
import functools
import random
-from typing import Optional
import torch
import torch.nn as nn
@@ -609,9 +608,9 @@ def inference_speech(
if input_tokens is None:
inputs = fake_inputs
else:
- assert (
- num_return_sequences % input_tokens.shape[0] == 0
- ), "The number of return sequences must be divisible by the number of input sequences"
+ assert num_return_sequences % input_tokens.shape[0] == 0, (
+ "The number of return sequences must be divisible by the number of input sequences"
+ )
fake_inputs = fake_inputs.repeat(num_return_sequences, 1)
input_tokens = input_tokens.repeat(num_return_sequences // input_tokens.shape[0], 1)
inputs = torch.cat([fake_inputs, input_tokens], dim=1)
@@ -640,8 +639,8 @@ def inference_speech(
def _prepare_attention_mask_for_generation(
inputs: torch.Tensor,
- pad_token_id: Optional[torch.Tensor],
- eos_token_id: Optional[torch.Tensor],
+ pad_token_id: torch.Tensor | None,
+ eos_token_id: torch.Tensor | None,
) -> torch.LongTensor:
# No information for attention mask inference -> return default attention mask
default_attention_mask = torch.ones(inputs.shape[:2], dtype=torch.long, device=inputs.device)
diff --git a/TTS/tts/layers/tortoise/diffusion.py b/TTS/tts/layers/tortoise/diffusion.py
index 2b29091b44..cfb8fa800d 100644
--- a/TTS/tts/layers/tortoise/diffusion.py
+++ b/TTS/tts/layers/tortoise/diffusion.py
@@ -653,7 +653,7 @@ def p_sample_loop_progressive(
"""
if device is None:
device = next(model.parameters()).device
- assert isinstance(shape, (tuple, list))
+ assert isinstance(shape, tuple | list)
if noise is not None:
img = noise
else:
@@ -805,7 +805,7 @@ def ddim_sample_loop_progressive(
"""
if device is None:
device = next(model.parameters()).device
- assert isinstance(shape, (tuple, list))
+ assert isinstance(shape, tuple | list)
if noise is not None:
img = noise
else:
diff --git a/TTS/tts/layers/tortoise/dpm_solver.py b/TTS/tts/layers/tortoise/dpm_solver.py
index 6a1d8ff784..c8892d456a 100644
--- a/TTS/tts/layers/tortoise/dpm_solver.py
+++ b/TTS/tts/layers/tortoise/dpm_solver.py
@@ -98,9 +98,7 @@ def __init__(
if schedule not in ["discrete", "linear", "cosine"]:
raise ValueError(
- "Unsupported noise schedule {}. The schedule needs to be 'discrete' or 'linear' or 'cosine'".format(
- schedule
- )
+ f"Unsupported noise schedule {schedule}. The schedule needs to be 'discrete' or 'linear' or 'cosine'"
)
self.schedule = schedule
@@ -150,7 +148,7 @@ def marginal_log_mean_coeff(self, t):
t.reshape((-1, 1)),
self.t_array.to(t.device),
self.log_alpha_array.to(t.device),
- ).reshape((-1))
+ ).reshape(-1)
elif self.schedule == "linear":
return -0.25 * t**2 * (self.beta_1 - self.beta_0) - 0.5 * t * self.beta_0
elif self.schedule == "cosine":
@@ -447,7 +445,7 @@ def correcting_xt_fn(xt, t, step):
Burcu Karagol Ayan, S Sara Mahdavi, Rapha Gontijo Lopes, et al. Photorealistic text-to-image diffusion models
with deep language understanding. arXiv preprint arXiv:2205.11487, 2022b.
"""
- self.model = lambda x, t: model_fn(x, t.expand((x.shape[0])))
+ self.model = lambda x, t: model_fn(x, t.expand(x.shape[0]))
self.noise_schedule = noise_schedule
assert algorithm_type in ["dpmsolver", "dpmsolver++"]
self.algorithm_type = algorithm_type
@@ -527,7 +525,7 @@ def get_time_steps(self, skip_type, t_T, t_0, N, device):
return t
else:
raise ValueError(
- "Unsupported skip_type {}, need to be 'logSNR' or 'time_uniform' or 'time_quadratic'".format(skip_type)
+ f"Unsupported skip_type {skip_type}, need to be 'logSNR' or 'time_uniform' or 'time_quadratic'"
)
def get_orders_and_timesteps_for_singlestep_solver(self, steps, order, skip_type, t_T, t_0, device):
@@ -565,41 +563,21 @@ def get_orders_and_timesteps_for_singlestep_solver(self, steps, order, skip_type
if order == 3:
K = steps // 3 + 1
if steps % 3 == 0:
- orders = [
- 3,
- ] * (
- K - 2
- ) + [2, 1]
+ orders = [3] * (K - 2) + [2, 1]
elif steps % 3 == 1:
- orders = [
- 3,
- ] * (
- K - 1
- ) + [1]
+ orders = [3] * (K - 1) + [1]
else:
- orders = [
- 3,
- ] * (
- K - 1
- ) + [2]
+ orders = [3] * (K - 1) + [2]
elif order == 2:
if steps % 2 == 0:
K = steps // 2
- orders = [
- 2,
- ] * K
+ orders = [2] * K
else:
K = steps // 2 + 1
- orders = [
- 2,
- ] * (
- K - 1
- ) + [1]
+ orders = [2] * (K - 1) + [1]
elif order == 1:
K = 1
- orders = [
- 1,
- ] * steps
+ orders = [1] * steps
else:
raise ValueError("'order' must be '1' or '2' or '3'.")
if skip_type == "logSNR":
@@ -607,15 +585,7 @@ def get_orders_and_timesteps_for_singlestep_solver(self, steps, order, skip_type
timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, K, device)
else:
timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, steps, device)[
- torch.cumsum(
- torch.tensor(
- [
- 0,
- ]
- + orders
- ),
- 0,
- ).to(device)
+ torch.cumsum(torch.tensor([0] + orders), 0).to(device)
]
return timesteps_outer, orders
@@ -693,7 +663,7 @@ def singlestep_dpm_solver_second_update(
x_t: A pytorch tensor. The approximated solution at time `t`.
"""
if solver_type not in ["dpmsolver", "taylor"]:
- raise ValueError("'solver_type' must be either 'dpmsolver' or 'taylor', got {}".format(solver_type))
+ raise ValueError(f"'solver_type' must be either 'dpmsolver' or 'taylor', got {solver_type}")
if r1 is None:
r1 = 0.5
ns = self.noise_schedule
@@ -790,7 +760,7 @@ def singlestep_dpm_solver_third_update(
x_t: A pytorch tensor. The approximated solution at time `t`.
"""
if solver_type not in ["dpmsolver", "taylor"]:
- raise ValueError("'solver_type' must be either 'dpmsolver' or 'taylor', got {}".format(solver_type))
+ raise ValueError(f"'solver_type' must be either 'dpmsolver' or 'taylor', got {solver_type}")
if r1 is None:
r1 = 1.0 / 3.0
if r2 is None:
@@ -913,7 +883,7 @@ def multistep_dpm_solver_second_update(self, x, model_prev_list, t_prev_list, t,
x_t: A pytorch tensor. The approximated solution at time `t`.
"""
if solver_type not in ["dpmsolver", "taylor"]:
- raise ValueError("'solver_type' must be either 'dpmsolver' or 'taylor', got {}".format(solver_type))
+ raise ValueError(f"'solver_type' must be either 'dpmsolver' or 'taylor', got {solver_type}")
ns = self.noise_schedule
model_prev_1, model_prev_0 = model_prev_list[-2], model_prev_list[-1]
t_prev_1, t_prev_0 = t_prev_list[-2], t_prev_list[-1]
@@ -1062,7 +1032,7 @@ def singlestep_dpm_solver_update(
r2=r2,
)
else:
- raise ValueError("Solver order must be 1 or 2 or 3, got {}".format(order))
+ raise ValueError(f"Solver order must be 1 or 2 or 3, got {order}")
def multistep_dpm_solver_update(self, x, model_prev_list, t_prev_list, t, order, solver_type="dpmsolver"):
"""
@@ -1086,7 +1056,7 @@ def multistep_dpm_solver_update(self, x, model_prev_list, t_prev_list, t, order,
elif order == 3:
return self.multistep_dpm_solver_third_update(x, model_prev_list, t_prev_list, t, solver_type=solver_type)
else:
- raise ValueError("Solver order must be 1 or 2 or 3, got {}".format(order))
+ raise ValueError(f"Solver order must be 1 or 2 or 3, got {order}")
def dpm_solver_adaptive(
self,
@@ -1150,8 +1120,8 @@ def higher_update(x, s, t, **kwargs):
return self.singlestep_dpm_solver_third_update(x, s, t, r1=r1, r2=r2, solver_type=solver_type, **kwargs)
else:
- raise ValueError("For adaptive step size solver, order must be 2 or 3, got {}".format(order))
- while torch.abs((s - t_0)).mean() > t_err:
+ raise ValueError(f"For adaptive step size solver, order must be 2 or 3, got {order}")
+ while torch.abs(s - t_0).mean() > t_err:
t = ns.inverse_lambda(lambda_s + h)
x_lower, lower_noise_kwargs = lower_update(x, s, t)
x_higher = higher_update(x, s, t, **lower_noise_kwargs)
@@ -1219,9 +1189,9 @@ def inverse(
"""
t_0 = 1.0 / self.noise_schedule.total_N if t_start is None else t_start
t_T = self.noise_schedule.T if t_end is None else t_end
- assert (
- t_0 > 0 and t_T > 0
- ), "Time range needs to be greater than 0. For discrete-time DPMs, it needs to be in [1 / N, 1], where N is the length of betas array"
+ assert t_0 > 0 and t_T > 0, (
+ "Time range needs to be greater than 0. For discrete-time DPMs, it needs to be in [1 / N, 1], where N is the length of betas array"
+ )
return self.sample(
x,
steps=steps,
@@ -1364,9 +1334,9 @@ def sample(
"""
t_0 = 1.0 / self.noise_schedule.total_N if t_end is None else t_end
t_T = self.noise_schedule.T if t_start is None else t_start
- assert (
- t_0 > 0 and t_T > 0
- ), "Time range needs to be greater than 0. For discrete-time DPMs, it needs to be in [1 / N, 1], where N is the length of betas array"
+ assert t_0 > 0 and t_T > 0, (
+ "Time range needs to be greater than 0. For discrete-time DPMs, it needs to be in [1 / N, 1], where N is the length of betas array"
+ )
if return_intermediate:
assert method in [
"multistep",
@@ -1487,7 +1457,7 @@ def sample(
if return_intermediate:
intermediates.append(x)
else:
- raise ValueError("Got wrong method {}".format(method))
+ raise ValueError(f"Got wrong method {method}")
if denoise_to_zero:
t = torch.ones((1,)).to(device) * t_0
x = self.denoise_to_zero_fn(x, t)
diff --git a/TTS/tts/layers/tortoise/transformer.py b/TTS/tts/layers/tortoise/transformer.py
index ed4d79d4ab..531f294220 100644
--- a/TTS/tts/layers/tortoise/transformer.py
+++ b/TTS/tts/layers/tortoise/transformer.py
@@ -1,4 +1,4 @@
-from typing import TypeVar, Union
+from typing import TypeVar
import torch
import torch.nn.functional as F
@@ -11,7 +11,7 @@
_T = TypeVar("_T")
-def cast_tuple(val: Union[tuple[_T], list[_T], _T], depth: int = 1) -> tuple[_T]:
+def cast_tuple(val: tuple[_T] | list[_T] | _T, depth: int = 1) -> tuple[_T]:
if isinstance(val, list):
return tuple(val)
return val if isinstance(val, tuple) else (val,) * depth
@@ -43,9 +43,9 @@ def route_args(router, args, depth):
class SequentialSequence(nn.Module):
def __init__(self, layers, args_route={}, layer_dropout=0.0):
super().__init__()
- assert all(
- len(route) == len(layers) for route in args_route.values()
- ), "each argument route map must have the same depth as the number of sequential layers"
+ assert all(len(route) == len(layers) for route in args_route.values()), (
+ "each argument route map must have the same depth as the number of sequential layers"
+ )
self.layers = layers
self.args_route = args_route
self.layer_dropout = layer_dropout
diff --git a/TTS/tts/layers/tortoise/vocoder.py b/TTS/tts/layers/tortoise/vocoder.py
index a5200c2673..e7497d8190 100644
--- a/TTS/tts/layers/tortoise/vocoder.py
+++ b/TTS/tts/layers/tortoise/vocoder.py
@@ -1,6 +1,6 @@
+from collections.abc import Callable
from dataclasses import dataclass
from enum import Enum
-from typing import Callable, Optional
import torch
import torch.nn as nn
@@ -293,7 +293,7 @@ def __init__(
hop_length=256,
n_mel_channels=100,
):
- super(UnivNetGenerator, self).__init__()
+ super().__init__()
self.mel_channel = n_mel_channels
self.noise_dim = noise_dim
self.hop_length = hop_length
@@ -344,7 +344,7 @@ def forward(self, c, z):
return z
def eval(self, inference=False):
- super(UnivNetGenerator, self).eval()
+ super().eval()
# don't remove weight norm while validation in training loop
if inference:
self.remove_weight_norm()
@@ -378,7 +378,7 @@ def inference(self, c, z=None):
class VocType:
constructor: Callable[[], nn.Module]
model_path: str
- subkey: Optional[str] = None
+ subkey: str | None = None
def optionally_index(self, model_dict):
if self.subkey is not None:
diff --git a/TTS/tts/layers/tortoise/xtransformers.py b/TTS/tts/layers/tortoise/xtransformers.py
index 0892fee19d..b2e74cf118 100644
--- a/TTS/tts/layers/tortoise/xtransformers.py
+++ b/TTS/tts/layers/tortoise/xtransformers.py
@@ -560,9 +560,9 @@ def __init__(
self.rel_pos_bias = rel_pos_bias
if rel_pos_bias:
- assert (
- rel_pos_num_buckets <= rel_pos_max_distance
- ), "number of relative position buckets must be less than the relative position max distance"
+ assert rel_pos_num_buckets <= rel_pos_max_distance, (
+ "number of relative position buckets must be less than the relative position max distance"
+ )
self.rel_pos = RelativePositionBias(
scale=dim_head**0.5,
causal=causal,
@@ -680,9 +680,9 @@ def forward(
del input_mask
if exists(attn_mask):
- assert (
- 2 <= attn_mask.ndim <= 4
- ), "attention mask must have greater than 2 dimensions but less than or equal to 4"
+ assert 2 <= attn_mask.ndim <= 4, (
+ "attention mask must have greater than 2 dimensions but less than or equal to 4"
+ )
if attn_mask.ndim == 2:
attn_mask = rearrange(attn_mask, "i j -> () () i j")
elif attn_mask.ndim == 3:
@@ -790,9 +790,9 @@ def __init__(
rotary_emb_dim = max(default(rotary_emb_dim, dim_head // 2), 32)
self.rotary_pos_emb = RotaryEmbedding(rotary_emb_dim) if rotary_pos_emb else None
- assert not (
- alibi_pos_bias and rel_pos_bias
- ), "you can only choose Alibi positional bias or T5 relative positional bias, not both"
+ assert not (alibi_pos_bias and rel_pos_bias), (
+ "you can only choose Alibi positional bias or T5 relative positional bias, not both"
+ )
if alibi_pos_bias:
alibi_num_heads = default(alibi_num_heads, heads)
@@ -922,9 +922,9 @@ def forward(
past_key_values=None,
expected_seq_len=None,
):
- assert not (
- self.cross_attend ^ (exists(context) or exists(full_context))
- ), "context must be passed in if cross_attend is set to True"
+ assert not (self.cross_attend ^ (exists(context) or exists(full_context))), (
+ "context must be passed in if cross_attend is set to True"
+ )
assert context is None or full_context is None, "only one of full_context or context can be provided"
hiddens = []
@@ -940,9 +940,9 @@ def forward(
rotary_pos_emb = None
if exists(self.rotary_pos_emb):
if not self.training and self.causal:
- assert (
- expected_seq_len is not None
- ), "To decode a transformer with rotary embeddings, you must specify an `expected_seq_len`"
+ assert expected_seq_len is not None, (
+ "To decode a transformer with rotary embeddings, you must specify an `expected_seq_len`"
+ )
elif expected_seq_len is None:
expected_seq_len = 0
seq_len = x.shape[1]
diff --git a/TTS/tts/layers/vits/transforms.py b/TTS/tts/layers/vits/transforms.py
index 3cac1b8d6d..da5deea9ef 100644
--- a/TTS/tts/layers/vits/transforms.py
+++ b/TTS/tts/layers/vits/transforms.py
@@ -74,7 +74,7 @@ def unconstrained_rational_quadratic_spline(
outputs[outside_interval_mask] = inputs[outside_interval_mask]
logabsdet[outside_interval_mask] = 0
else:
- raise RuntimeError("{} tails are not implemented.".format(tails))
+ raise RuntimeError(f"{tails} tails are not implemented.")
outputs[inside_interval_mask], logabsdet[inside_interval_mask] = rational_quadratic_spline(
inputs=inputs[inside_interval_mask],
diff --git a/TTS/tts/layers/xtts/gpt.py b/TTS/tts/layers/xtts/gpt.py
index 20eff26ecc..4e0f53616d 100644
--- a/TTS/tts/layers/xtts/gpt.py
+++ b/TTS/tts/layers/xtts/gpt.py
@@ -347,12 +347,12 @@ def forward(
audio_codes = F.pad(audio_codes, (0, max_mel_len - audio_codes.shape[-1]))
# 💖 Lovely assertions
- assert (
- max_mel_len <= audio_codes.shape[-1]
- ), f" ❗ max_mel_len ({max_mel_len}) > audio_codes.shape[-1] ({audio_codes.shape[-1]})"
- assert (
- max_text_len <= text_inputs.shape[-1]
- ), f" ❗ max_text_len ({max_text_len}) > text_inputs.shape[-1] ({text_inputs.shape[-1]})"
+ assert max_mel_len <= audio_codes.shape[-1], (
+ f" ❗ max_mel_len ({max_mel_len}) > audio_codes.shape[-1] ({audio_codes.shape[-1]})"
+ )
+ assert max_text_len <= text_inputs.shape[-1], (
+ f" ❗ max_text_len ({max_text_len}) > text_inputs.shape[-1] ({text_inputs.shape[-1]})"
+ )
# Append stop token to text inputs
text_inputs = F.pad(text_inputs[:, :max_text_len], (0, 1), value=self.stop_text_token)
@@ -454,9 +454,9 @@ def forward(
mel_targets[idx, l + 1 :] = -1
# check if stoptoken is in every row of mel_targets
- assert (mel_targets == self.stop_audio_token).sum() >= mel_targets.shape[
- 0
- ], f" ❗ mel_targets does not contain stop token ({self.stop_audio_token}) in every row."
+ assert (mel_targets == self.stop_audio_token).sum() >= mel_targets.shape[0], (
+ f" ❗ mel_targets does not contain stop token ({self.stop_audio_token}) in every row."
+ )
# ignore the loss for the segment used for conditioning
# coin flip for the segment to be ignored
diff --git a/TTS/tts/layers/xtts/stream_generator.py b/TTS/tts/layers/xtts/stream_generator.py
index 2f4b54cec1..9343f656e1 100644
--- a/TTS/tts/layers/xtts/stream_generator.py
+++ b/TTS/tts/layers/xtts/stream_generator.py
@@ -4,7 +4,7 @@
import inspect
import random
import warnings
-from typing import Callable, Optional, Union
+from collections.abc import Callable
import numpy as np
import torch
@@ -48,15 +48,15 @@ class NewGenerationMixin(GenerationMixin):
@torch.inference_mode()
def generate( # noqa: PLR0911
self,
- inputs: Optional[torch.Tensor] = None,
- generation_config: Optional[StreamGenerationConfig] = None,
- logits_processor: Optional[LogitsProcessorList] = None,
- stopping_criteria: Optional[StoppingCriteriaList] = None,
- prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], list[int]]] = None,
- synced_gpus: Optional[bool] = False,
+ inputs: torch.Tensor | None = None,
+ generation_config: StreamGenerationConfig | None = None,
+ logits_processor: LogitsProcessorList | None = None,
+ stopping_criteria: StoppingCriteriaList | None = None,
+ prefix_allowed_tokens_fn: Callable[[int, torch.Tensor], list[int]] | None = None,
+ synced_gpus: bool | None = False,
seed: int = 0,
**kwargs,
- ) -> Union[GenerateOutput, torch.LongTensor]:
+ ) -> GenerateOutput | torch.LongTensor:
r"""
Generates sequences of token ids for models with a language modeling head.
@@ -207,8 +207,8 @@ def generate( # noqa: PLR0911
)
model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation(
inputs_tensor,
- generation_config._pad_token_tensor,
- generation_config._eos_token_tensor,
+ generation_config,
+ model_kwargs,
)
# decoder-only models should use left-padding for generation
@@ -666,19 +666,19 @@ def typeerror():
def sample_stream(
self,
input_ids: torch.LongTensor,
- logits_processor: Optional[LogitsProcessorList] = None,
- stopping_criteria: Optional[StoppingCriteriaList] = None,
- logits_warper: Optional[LogitsProcessorList] = None,
- max_length: Optional[int] = None,
- pad_token_id: Optional[int] = None,
- eos_token_id: Optional[Union[int, list[int]]] = None,
- output_attentions: Optional[bool] = None,
- output_hidden_states: Optional[bool] = None,
- output_scores: Optional[bool] = None,
- return_dict_in_generate: Optional[bool] = None,
- synced_gpus: Optional[bool] = False,
+ logits_processor: LogitsProcessorList | None = None,
+ stopping_criteria: StoppingCriteriaList | None = None,
+ logits_warper: LogitsProcessorList | None = None,
+ max_length: int | None = None,
+ pad_token_id: int | None = None,
+ eos_token_id: int | list[int] | None = None,
+ output_attentions: bool | None = None,
+ output_hidden_states: bool | None = None,
+ output_scores: bool | None = None,
+ return_dict_in_generate: bool | None = None,
+ synced_gpus: bool | None = False,
**model_kwargs,
- ) -> Union[SampleOutput, torch.LongTensor]:
+ ) -> SampleOutput | torch.LongTensor:
r"""
Generates sequences of token ids for models with a language modeling head using **multinomial sampling** and
can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models.
@@ -953,7 +953,6 @@ def init_stream_support():
def _get_logits_warper(generation_config: GenerationConfig) -> LogitsProcessorList:
-
warpers = LogitsProcessorList()
if generation_config.temperature is not None and generation_config.temperature != 1.0:
diff --git a/TTS/tts/layers/xtts/tokenizer.py b/TTS/tts/layers/xtts/tokenizer.py
index fec8358deb..ef4162a1cb 100644
--- a/TTS/tts/layers/xtts/tokenizer.py
+++ b/TTS/tts/layers/xtts/tokenizer.py
@@ -76,7 +76,7 @@ def split_sentence(text, lang, text_split_length=250):
# List of (regular expression, replacement) pairs for abbreviations:
_abbreviations = {
"en": [
- (re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
+ (re.compile(f"\\b{x[0]}\\.", re.IGNORECASE), x[1])
for x in [
("mrs", "misess"),
("mr", "mister"),
@@ -99,7 +99,7 @@ def split_sentence(text, lang, text_split_length=250):
]
],
"es": [
- (re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
+ (re.compile(f"\\b{x[0]}\\.", re.IGNORECASE), x[1])
for x in [
("sra", "señora"),
("sr", "señor"),
@@ -112,7 +112,7 @@ def split_sentence(text, lang, text_split_length=250):
]
],
"fr": [
- (re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
+ (re.compile(f"\\b{x[0]}\\.", re.IGNORECASE), x[1])
for x in [
("mme", "madame"),
("mr", "monsieur"),
@@ -124,7 +124,7 @@ def split_sentence(text, lang, text_split_length=250):
]
],
"de": [
- (re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
+ (re.compile(f"\\b{x[0]}\\.", re.IGNORECASE), x[1])
for x in [
("fr", "frau"),
("dr", "doktor"),
@@ -134,7 +134,7 @@ def split_sentence(text, lang, text_split_length=250):
]
],
"pt": [
- (re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
+ (re.compile(f"\\b{x[0]}\\.", re.IGNORECASE), x[1])
for x in [
("sra", "senhora"),
("sr", "senhor"),
@@ -147,7 +147,7 @@ def split_sentence(text, lang, text_split_length=250):
]
],
"it": [
- (re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
+ (re.compile(f"\\b{x[0]}\\.", re.IGNORECASE), x[1])
for x in [
# ("sig.ra", "signora"),
("sig", "signore"),
@@ -159,7 +159,7 @@ def split_sentence(text, lang, text_split_length=250):
]
],
"pl": [
- (re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
+ (re.compile(f"\\b{x[0]}\\.", re.IGNORECASE), x[1])
for x in [
("p", "pani"),
("m", "pan"),
@@ -169,19 +169,19 @@ def split_sentence(text, lang, text_split_length=250):
]
],
"ar": [
- (re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
+ (re.compile(f"\\b{x[0]}\\.", re.IGNORECASE), x[1])
for x in [
# There are not many common abbreviations in Arabic as in English.
]
],
"zh": [
- (re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
+ (re.compile(f"\\b{x[0]}\\.", re.IGNORECASE), x[1])
for x in [
# Chinese doesn't typically use abbreviations in the same way as Latin-based scripts.
]
],
"cs": [
- (re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
+ (re.compile(f"\\b{x[0]}\\.", re.IGNORECASE), x[1])
for x in [
("dr", "doktor"), # doctor
("ing", "inženýr"), # engineer
@@ -190,7 +190,7 @@ def split_sentence(text, lang, text_split_length=250):
]
],
"ru": [
- (re.compile("\\b%s\\b" % x[0], re.IGNORECASE), x[1])
+ (re.compile(f"\\b{x[0]}\\b", re.IGNORECASE), x[1])
for x in [
("г-жа", "госпожа"), # Mrs.
("г-н", "господин"), # Mr.
@@ -199,7 +199,7 @@ def split_sentence(text, lang, text_split_length=250):
]
],
"nl": [
- (re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
+ (re.compile(f"\\b{x[0]}\\.", re.IGNORECASE), x[1])
for x in [
("dhr", "de heer"), # Mr.
("mevr", "mevrouw"), # Mrs.
@@ -209,7 +209,7 @@ def split_sentence(text, lang, text_split_length=250):
]
],
"tr": [
- (re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
+ (re.compile(f"\\b{x[0]}\\.", re.IGNORECASE), x[1])
for x in [
("b", "bay"), # Mr.
("byk", "büyük"), # büyük
@@ -218,7 +218,7 @@ def split_sentence(text, lang, text_split_length=250):
]
],
"hu": [
- (re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
+ (re.compile(f"\\b{x[0]}\\.", re.IGNORECASE), x[1])
for x in [
("dr", "doktor"), # doctor
("b", "bácsi"), # Mr.
@@ -227,13 +227,13 @@ def split_sentence(text, lang, text_split_length=250):
]
],
"ko": [
- (re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
+ (re.compile(f"\\b{x[0]}\\.", re.IGNORECASE), x[1])
for x in [
# Korean doesn't typically use abbreviations in the same way as Latin-based scripts.
]
],
"hi": [
- (re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
+ (re.compile(f"\\b{x[0]}\\.", re.IGNORECASE), x[1])
for x in [
# Hindi doesn't typically use abbreviations in the same way as Latin-based scripts.
]
@@ -249,7 +249,7 @@ def expand_abbreviations_multilingual(text, lang="en"):
_symbols_multilingual = {
"en": [
- (re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1])
+ (re.compile(rf"{re.escape(x[0])}", re.IGNORECASE), x[1])
for x in [
("&", " and "),
("@", " at "),
@@ -261,7 +261,7 @@ def expand_abbreviations_multilingual(text, lang="en"):
]
],
"es": [
- (re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1])
+ (re.compile(rf"{re.escape(x[0])}", re.IGNORECASE), x[1])
for x in [
("&", " y "),
("@", " arroba "),
@@ -273,7 +273,7 @@ def expand_abbreviations_multilingual(text, lang="en"):
]
],
"fr": [
- (re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1])
+ (re.compile(rf"{re.escape(x[0])}", re.IGNORECASE), x[1])
for x in [
("&", " et "),
("@", " arobase "),
@@ -285,7 +285,7 @@ def expand_abbreviations_multilingual(text, lang="en"):
]
],
"de": [
- (re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1])
+ (re.compile(rf"{re.escape(x[0])}", re.IGNORECASE), x[1])
for x in [
("&", " und "),
("@", " at "),
@@ -297,7 +297,7 @@ def expand_abbreviations_multilingual(text, lang="en"):
]
],
"pt": [
- (re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1])
+ (re.compile(rf"{re.escape(x[0])}", re.IGNORECASE), x[1])
for x in [
("&", " e "),
("@", " arroba "),
@@ -309,7 +309,7 @@ def expand_abbreviations_multilingual(text, lang="en"):
]
],
"it": [
- (re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1])
+ (re.compile(rf"{re.escape(x[0])}", re.IGNORECASE), x[1])
for x in [
("&", " e "),
("@", " chiocciola "),
@@ -321,7 +321,7 @@ def expand_abbreviations_multilingual(text, lang="en"):
]
],
"pl": [
- (re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1])
+ (re.compile(rf"{re.escape(x[0])}", re.IGNORECASE), x[1])
for x in [
("&", " i "),
("@", " małpa "),
@@ -334,7 +334,7 @@ def expand_abbreviations_multilingual(text, lang="en"):
],
"ar": [
# Arabic
- (re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1])
+ (re.compile(rf"{re.escape(x[0])}", re.IGNORECASE), x[1])
for x in [
("&", " و "),
("@", " على "),
@@ -347,7 +347,7 @@ def expand_abbreviations_multilingual(text, lang="en"):
],
"zh": [
# Chinese
- (re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1])
+ (re.compile(rf"{re.escape(x[0])}", re.IGNORECASE), x[1])
for x in [
("&", " 和 "),
("@", " 在 "),
@@ -360,7 +360,7 @@ def expand_abbreviations_multilingual(text, lang="en"):
],
"cs": [
# Czech
- (re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1])
+ (re.compile(rf"{re.escape(x[0])}", re.IGNORECASE), x[1])
for x in [
("&", " a "),
("@", " na "),
@@ -373,7 +373,7 @@ def expand_abbreviations_multilingual(text, lang="en"):
],
"ru": [
# Russian
- (re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1])
+ (re.compile(rf"{re.escape(x[0])}", re.IGNORECASE), x[1])
for x in [
("&", " и "),
("@", " собака "),
@@ -386,7 +386,7 @@ def expand_abbreviations_multilingual(text, lang="en"):
],
"nl": [
# Dutch
- (re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1])
+ (re.compile(rf"{re.escape(x[0])}", re.IGNORECASE), x[1])
for x in [
("&", " en "),
("@", " bij "),
@@ -398,7 +398,7 @@ def expand_abbreviations_multilingual(text, lang="en"):
]
],
"tr": [
- (re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1])
+ (re.compile(rf"{re.escape(x[0])}", re.IGNORECASE), x[1])
for x in [
("&", " ve "),
("@", " at "),
@@ -410,7 +410,7 @@ def expand_abbreviations_multilingual(text, lang="en"):
]
],
"hu": [
- (re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1])
+ (re.compile(rf"{re.escape(x[0])}", re.IGNORECASE), x[1])
for x in [
("&", " és "),
("@", " kukac "),
@@ -423,7 +423,7 @@ def expand_abbreviations_multilingual(text, lang="en"):
],
"ko": [
# Korean
- (re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1])
+ (re.compile(rf"{re.escape(x[0])}", re.IGNORECASE), x[1])
for x in [
("&", " 그리고 "),
("@", " 에 "),
@@ -435,7 +435,7 @@ def expand_abbreviations_multilingual(text, lang="en"):
]
],
"hi": [
- (re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1])
+ (re.compile(rf"{re.escape(x[0])}", re.IGNORECASE), x[1])
for x in [
("&", " और "),
("@", " ऐट दी रेट "),
@@ -505,7 +505,7 @@ def _expand_decimal_point(m, lang="en"):
def _expand_currency(m, lang="en", currency="USD"):
- amount = float((re.sub(r"[^\d.]", "", m.group(0).replace(",", "."))))
+ amount = float(re.sub(r"[^\d.]", "", m.group(0).replace(",", ".")))
full_amount = num2words(amount, to="currency", currency=currency, lang=lang)
and_equivalents = {
diff --git a/TTS/tts/layers/xtts/trainer/gpt_trainer.py b/TTS/tts/layers/xtts/trainer/gpt_trainer.py
index 9e8e753a61..f1d40f7fe3 100644
--- a/TTS/tts/layers/xtts/trainer/gpt_trainer.py
+++ b/TTS/tts/layers/xtts/trainer/gpt_trainer.py
@@ -1,6 +1,5 @@
import logging
from dataclasses import dataclass, field
-from typing import Dict, List, Tuple, Union
import torch
import torch.nn as nn
@@ -24,16 +23,6 @@
logger = logging.getLogger(__name__)
-@dataclass
-class GPTTrainerConfig(XttsConfig):
- lr: float = 5e-06
- training_seed: int = 1
- optimizer_wd_only_on_weights: bool = False
- weighted_loss_attrs: dict = field(default_factory=lambda: {})
- weighted_loss_multipliers: dict = field(default_factory=lambda: {})
- test_sentences: List[dict] = field(default_factory=lambda: [])
-
-
@dataclass
class GPTArgs(XttsArgs):
min_conditioning_length: int = 66150
@@ -52,6 +41,17 @@ class GPTArgs(XttsArgs):
vocoder: str = "" # overide vocoder key on the config to avoid json write issues
+@dataclass
+class GPTTrainerConfig(XttsConfig):
+ lr: float = 5e-06
+ training_seed: int = 1
+ optimizer_wd_only_on_weights: bool = False
+ weighted_loss_attrs: dict = field(default_factory=lambda: {})
+ weighted_loss_multipliers: dict = field(default_factory=lambda: {})
+ test_sentences: list[dict] = field(default_factory=lambda: [])
+ model_args: GPTArgs = field(default_factory=GPTArgs)
+
+
def callback_clearml_load_save(operation_type, model_info):
# return None means skip the file upload/log, returning model_info will continue with the log/upload
# you can also change the upload destination file name model_info.upload_filename or check the local file size with Path(model_info.local_model_path).stat().st_size
@@ -222,7 +222,7 @@ def forward(self, text_inputs, text_lengths, audio_codes, wav_lengths, cond_mels
return losses
@torch.inference_mode()
- def test_run(self, assets) -> Tuple[Dict, Dict]: # pylint: disable=W0613
+ def test_run(self, assets) -> tuple[dict, dict]: # pylint: disable=W0613
test_audios = {}
if self.config.test_sentences:
# init gpt for inference mode
@@ -237,7 +237,7 @@ def test_run(self, assets) -> Tuple[Dict, Dict]: # pylint: disable=W0613
s_info["language"],
gpt_cond_len=3,
)["wav"]
- test_audios["{}-audio".format(idx)] = wav
+ test_audios[f"{idx}-audio"] = wav
# delete inference layers
del self.xtts.gpt.gpt_inference
@@ -245,11 +245,15 @@ def test_run(self, assets) -> Tuple[Dict, Dict]: # pylint: disable=W0613
return {"audios": test_audios}
def test_log(
- self, outputs: dict, logger: "Logger", assets: dict, steps: int # pylint: disable=unused-argument
+ self,
+ outputs: dict,
+ logger: "Logger",
+ assets: dict,
+ steps: int, # pylint: disable=unused-argument
) -> None:
logger.test_audios(steps, outputs["audios"], self.args.output_sample_rate)
- def format_batch(self, batch: Dict) -> Dict:
+ def format_batch(self, batch: dict) -> dict:
return batch
@torch.no_grad() # torch no grad to avoid gradients from the pre-processing and DVAE codes extraction
@@ -351,12 +355,12 @@ def get_sampler(self, dataset: TTSDataset, num_gpus=1):
def get_data_loader(
self,
config: Coqpit,
- assets: Dict,
+ assets: dict,
is_eval: bool,
- samples: Union[List[Dict], List[List]],
+ samples: list[dict] | list[list],
verbose: bool,
num_gpus: int,
- rank: int = None,
+ rank: int | None = None,
) -> "DataLoader": # pylint: disable=W0613
if is_eval and not config.run_eval:
loader = None
@@ -396,7 +400,7 @@ def get_data_loader(
)
return loader
- def get_optimizer(self) -> List:
+ def get_optimizer(self) -> list:
"""Initiate and return the optimizer based on the config parameters."""
# ToDo: deal with multi GPU training
if self.config.optimizer_wd_only_on_weights:
@@ -427,7 +431,7 @@ def get_optimizer(self) -> List:
v.is_norm = isinstance(m, norm_modules)
v.is_emb = isinstance(m, emb_modules)
- fpn = "%s.%s" % (mn, k) if mn else k # full param name
+ fpn = f"{mn}.{k}" if mn else k # full param name
all_param_names.add(fpn)
param_map[fpn] = v
if v.is_bias or v.is_norm or v.is_emb:
@@ -460,7 +464,7 @@ def get_optimizer(self) -> List:
parameters=self.xtts.gpt.parameters(),
)
- def get_scheduler(self, optimizer) -> List:
+ def get_scheduler(self, optimizer) -> list:
"""Set the scheduler for the optimizer.
Args:
@@ -491,7 +495,7 @@ def load_checkpoint(
assert not self.training
@staticmethod
- def init_from_config(config: "GPTTrainerConfig", samples: Union[List[List], List[Dict]] = None):
+ def init_from_config(config: "GPTTrainerConfig", samples: list[list] | list[dict] = None):
"""Initiate model from config
Args:
diff --git a/TTS/tts/layers/xtts/zh_num2words.py b/TTS/tts/layers/xtts/zh_num2words.py
index 69b8dae952..360d9b06c8 100644
--- a/TTS/tts/layers/xtts/zh_num2words.py
+++ b/TTS/tts/layers/xtts/zh_num2words.py
@@ -392,7 +392,7 @@
# ================================================================================ #
# basic class
# ================================================================================ #
-class ChineseChar(object):
+class ChineseChar:
"""
中文字符
每个字符对应简体和繁体,
@@ -420,13 +420,13 @@ class ChineseNumberUnit(ChineseChar):
"""
def __init__(self, power, simplified, traditional, big_s, big_t):
- super(ChineseNumberUnit, self).__init__(simplified, traditional)
+ super().__init__(simplified, traditional)
self.power = power
self.big_s = big_s
self.big_t = big_t
def __str__(self):
- return "10^{}".format(self.power)
+ return f"10^{self.power}"
@classmethod
def create(cls, index, value, numbering_type=NUMBERING_TYPES[1], small_unit=False):
@@ -447,7 +447,7 @@ def create(cls, index, value, numbering_type=NUMBERING_TYPES[1], small_unit=Fals
power=pow(2, index + 3), simplified=value[0], traditional=value[1], big_s=value[0], big_t=value[1]
)
else:
- raise ValueError("Counting type should be in {0} ({1} provided).".format(NUMBERING_TYPES, numbering_type))
+ raise ValueError(f"Counting type should be in {NUMBERING_TYPES} ({numbering_type} provided).")
class ChineseNumberDigit(ChineseChar):
@@ -456,7 +456,7 @@ class ChineseNumberDigit(ChineseChar):
"""
def __init__(self, value, simplified, traditional, big_s, big_t, alt_s=None, alt_t=None):
- super(ChineseNumberDigit, self).__init__(simplified, traditional)
+ super().__init__(simplified, traditional)
self.value = value
self.big_s = big_s
self.big_t = big_t
@@ -477,7 +477,7 @@ class ChineseMath(ChineseChar):
"""
def __init__(self, simplified, traditional, symbol, expression=None):
- super(ChineseMath, self).__init__(simplified, traditional)
+ super().__init__(simplified, traditional)
self.symbol = symbol
self.expression = expression
self.big_s = simplified
@@ -487,13 +487,13 @@ def __init__(self, simplified, traditional, symbol, expression=None):
CC, CNU, CND, CM = ChineseChar, ChineseNumberUnit, ChineseNumberDigit, ChineseMath
-class NumberSystem(object):
+class NumberSystem:
"""
中文数字系统
"""
-class MathSymbol(object):
+class MathSymbol:
"""
用于中文数字系统的数学符号 (繁/简体), e.g.
positive = ['正', '正']
@@ -507,8 +507,7 @@ def __init__(self, positive, negative, point):
self.point = point
def __iter__(self):
- for v in self.__dict__.values():
- yield v
+ yield from self.__dict__.values()
# class OtherSymbol(object):
@@ -640,7 +639,7 @@ def compute_value(integer_symbols):
int_str = str(compute_value(int_part))
dec_str = "".join([str(d.value) for d in dec_part])
if dec_part:
- return "{0}.{1}".format(int_str, dec_str)
+ return f"{int_str}.{dec_str}"
else:
return int_str
@@ -686,7 +685,7 @@ def get_value(value_string, use_zeros=True):
int_string = int_dec[0]
dec_string = int_dec[1]
else:
- raise ValueError("invalid input num string with more than one dot: {}".format(number_string))
+ raise ValueError(f"invalid input num string with more than one dot: {number_string}")
if use_units and len(int_string) > 1:
result_symbols = get_value(int_string)
@@ -702,7 +701,7 @@ def get_value(value_string, use_zeros=True):
if isinstance(v, CND) and v.value == 2:
next_symbol = result_symbols[i + 1] if i < len(result_symbols) - 1 else None
previous_symbol = result_symbols[i - 1] if i > 0 else None
- if isinstance(next_symbol, CNU) and isinstance(previous_symbol, (CNU, type(None))):
+ if isinstance(next_symbol, CNU) and isinstance(previous_symbol, CNU | type(None)):
if next_symbol.power != 1 and ((previous_symbol is None) or (previous_symbol.power != 1)):
result_symbols[i] = liang
@@ -1166,7 +1165,7 @@ def __call__(self, text):
)
ndone = 0
- with open(args.ifile, "r", encoding="utf8") as istream, open(args.ofile, "w+", encoding="utf8") as ostream:
+ with open(args.ifile, encoding="utf8") as istream, open(args.ofile, "w+", encoding="utf8") as ostream:
if args.format == "tsv":
reader = csv.DictReader(istream, delimiter="\t")
assert "TEXT" in reader.fieldnames
diff --git a/TTS/tts/models/__init__.py b/TTS/tts/models/__init__.py
index ebfa171c80..4746b13ea2 100644
--- a/TTS/tts/models/__init__.py
+++ b/TTS/tts/models/__init__.py
@@ -1,12 +1,11 @@
import logging
-from typing import Dict, List, Union
from TTS.utils.generic_utils import find_module
logger = logging.getLogger(__name__)
-def setup_model(config: "Coqpit", samples: Union[List[List], List[Dict]] = None) -> "BaseTTS":
+def setup_model(config: "Coqpit", samples: list[list] | list[dict] | None = None) -> "BaseTTS":
logger.info("Using model: %s", config.model)
# fetch the right model implementation.
if "base_model" in config and config["base_model"] is not None:
diff --git a/TTS/tts/models/align_tts.py b/TTS/tts/models/align_tts.py
index c1d0cf0aea..c2e29c7100 100644
--- a/TTS/tts/models/align_tts.py
+++ b/TTS/tts/models/align_tts.py
@@ -1,5 +1,4 @@
from dataclasses import dataclass, field
-from typing import Dict, List, Union
import torch
from coqpit import Coqpit
@@ -233,9 +232,7 @@ def _forward_mdn(self, o_en, y, y_lengths, x_mask):
dr_mas, logp = self.compute_align_path(mu, log_sigma, y, x_mask, y_mask)
return dr_mas, mu, log_sigma, logp
- def forward(
- self, x, x_lengths, y, y_lengths, aux_input={"d_vectors": None}, phase=None
- ): # pylint: disable=unused-argument
+ def forward(self, x, x_lengths, y, y_lengths, aux_input={"d_vectors": None}, phase=None): # pylint: disable=unused-argument
"""
Shapes:
- x: :math:`[B, T_max]`
@@ -352,9 +349,7 @@ def _create_logs(self, batch, outputs, ap): # pylint: disable=no-self-use
train_audio = ap.inv_melspectrogram(pred_spec.T)
return figures, {"audio": train_audio}
- def train_log(
- self, batch: dict, outputs: dict, logger: "Logger", assets: dict, steps: int
- ) -> None: # pylint: disable=no-self-use
+ def train_log(self, batch: dict, outputs: dict, logger: "Logger", assets: dict, steps: int) -> None: # pylint: disable=no-self-use
figures, audios = self._create_logs(batch, outputs, self.ap)
logger.train_figures(steps, figures)
logger.train_audios(steps, audios, self.ap.sample_rate)
@@ -367,9 +362,7 @@ def eval_log(self, batch: dict, outputs: dict, logger: "Logger", assets: dict, s
logger.eval_figures(steps, figures)
logger.eval_audios(steps, audios, self.ap.sample_rate)
- def load_checkpoint(
- self, config, checkpoint_path, eval=False, cache=False
- ): # pylint: disable=unused-argument, redefined-builtin
+ def load_checkpoint(self, config, checkpoint_path, eval=False, cache=False): # pylint: disable=unused-argument, redefined-builtin
state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"), cache=cache)
self.load_state_dict(state["model"])
if eval:
@@ -403,7 +396,7 @@ def on_epoch_start(self, trainer):
self.phase = self._set_phase(trainer.config, trainer.total_steps_done)
@staticmethod
- def init_from_config(config: "AlignTTSConfig", samples: Union[List[List], List[Dict]] = None):
+ def init_from_config(config: "AlignTTSConfig", samples: list[list] | list[dict] = None):
"""Initiate model from config
Args:
diff --git a/TTS/tts/models/bark.py b/TTS/tts/models/bark.py
index 6a480e6f5c..84814745a2 100644
--- a/TTS/tts/models/bark.py
+++ b/TTS/tts/models/bark.py
@@ -1,7 +1,6 @@
import os
from dataclasses import dataclass
from pathlib import Path
-from typing import Optional
import numpy as np
from coqpit import Coqpit
@@ -65,7 +64,7 @@ def train_step(
def text_to_semantic(
self,
text: str,
- history_prompt: Optional[str] = None,
+ history_prompt: str | None = None,
temp: float = 0.7,
base=None,
allow_early_stop=True,
@@ -95,7 +94,7 @@ def text_to_semantic(
def semantic_to_waveform(
self,
semantic_tokens: np.ndarray,
- history_prompt: Optional[str] = None,
+ history_prompt: str | None = None,
temp: float = 0.7,
base=None,
):
@@ -129,7 +128,7 @@ def semantic_to_waveform(
def generate_audio(
self,
text: str,
- history_prompt: Optional[str] = None,
+ history_prompt: str | None = None,
text_temp: float = 0.7,
waveform_temp: float = 0.7,
base=None,
@@ -191,9 +190,7 @@ def _set_voice_dirs(self, voice_dirs):
return _voice_dirs
# TODO: remove config from synthesize
- def synthesize(
- self, text, config, speaker_id="random", voice_dirs=None, **kwargs
- ): # pylint: disable=unused-argument
+ def synthesize(self, text, config, speaker_id="random", voice_dirs=None, **kwargs): # pylint: disable=unused-argument
"""Synthesize speech with the given input text.
Args:
diff --git a/TTS/tts/models/base_tacotron.py b/TTS/tts/models/base_tacotron.py
index 79cdf1a7d4..05f4ae168d 100644
--- a/TTS/tts/models/base_tacotron.py
+++ b/TTS/tts/models/base_tacotron.py
@@ -1,7 +1,6 @@
import copy
import logging
from abc import abstractmethod
-from typing import Dict, Tuple
import torch
from coqpit import Coqpit
@@ -62,7 +61,7 @@ def __init__(
self.coarse_decoder = None
@staticmethod
- def _format_aux_input(aux_input: Dict) -> Dict:
+ def _format_aux_input(aux_input: dict) -> dict:
"""Set missing fields to their default values"""
if aux_input:
return format_aux_input({"d_vectors": None, "speaker_ids": None}, aux_input)
@@ -94,9 +93,7 @@ def forward(self):
def inference(self):
pass
- def load_checkpoint(
- self, config, checkpoint_path, eval=False, cache=False
- ): # pylint: disable=unused-argument, redefined-builtin
+ def load_checkpoint(self, config, checkpoint_path, eval=False, cache=False): # pylint: disable=unused-argument, redefined-builtin
"""Load model checkpoint and set up internals.
Args:
@@ -141,7 +138,7 @@ def init_from_config(config: Coqpit):
# TEST AND LOG FUNCTIONS #
##########################
- def test_run(self, assets: Dict) -> Tuple[Dict, Dict]:
+ def test_run(self, assets: dict) -> tuple[dict, dict]:
"""Generic test run for `tts` models used by `Trainer`.
You can override this for a different behaviour.
@@ -169,17 +166,19 @@ def test_run(self, assets: Dict) -> Tuple[Dict, Dict]:
use_griffin_lim=True,
do_trim_silence=False,
)
- test_audios["{}-audio".format(idx)] = outputs_dict["wav"]
- test_figures["{}-prediction".format(idx)] = plot_spectrogram(
+ test_audios[f"{idx}-audio"] = outputs_dict["wav"]
+ test_figures[f"{idx}-prediction"] = plot_spectrogram(
outputs_dict["outputs"]["model_outputs"], self.ap, output_fig=False
)
- test_figures["{}-alignment".format(idx)] = plot_alignment(
- outputs_dict["outputs"]["alignments"], output_fig=False
- )
+ test_figures[f"{idx}-alignment"] = plot_alignment(outputs_dict["outputs"]["alignments"], output_fig=False)
return {"figures": test_figures, "audios": test_audios}
def test_log(
- self, outputs: dict, logger: "Logger", assets: dict, steps: int # pylint: disable=unused-argument
+ self,
+ outputs: dict,
+ logger: "Logger",
+ assets: dict,
+ steps: int, # pylint: disable=unused-argument
) -> None:
logger.test_audios(steps, outputs["audios"], self.ap.sample_rate)
logger.test_figures(steps, outputs["figures"])
diff --git a/TTS/tts/models/base_tts.py b/TTS/tts/models/base_tts.py
index 33a75598c9..95cbf5bbf5 100644
--- a/TTS/tts/models/base_tts.py
+++ b/TTS/tts/models/base_tts.py
@@ -1,7 +1,6 @@
import logging
import os
import random
-from typing import Dict, List, Tuple, Union
import torch
import torch.distributed as dist
@@ -79,7 +78,7 @@ def _set_model_args(self, config: Coqpit):
else:
raise ValueError("config must be either a *Config or *Args")
- def init_multispeaker(self, config: Coqpit, data: List = None):
+ def init_multispeaker(self, config: Coqpit, data: list | None = None):
"""Set up for multi-speaker TTS.
Initialize a speaker embedding layer if needed and define expected embedding
@@ -114,7 +113,7 @@ def init_multispeaker(self, config: Coqpit, data: List = None):
self.speaker_embedding = nn.Embedding(self.num_speakers, self.embedded_speaker_dim)
self.speaker_embedding.weight.data.normal_(0, 0.3)
- def get_aux_input(self, **kwargs) -> Dict:
+ def get_aux_input(self, **kwargs) -> dict:
"""Prepare and return `aux_input` used by `forward()`"""
return {"speaker_id": None, "style_wav": None, "d_vector": None, "language_id": None}
@@ -165,7 +164,7 @@ def get_aux_input_from_test_sentences(self, sentence_info):
"language_id": language_id,
}
- def format_batch(self, batch: Dict) -> Dict:
+ def format_batch(self, batch: dict) -> dict:
"""Generic batch formatting for `TTSDataset`.
You must override this if you use a custom dataset.
@@ -211,9 +210,9 @@ def format_batch(self, batch: Dict) -> Dict:
extra_frames = dur.sum() - mel_lengths[idx]
largest_idxs = torch.argsort(-dur)[:extra_frames]
dur[largest_idxs] -= 1
- assert (
- dur.sum() == mel_lengths[idx]
- ), f" [!] total duration {dur.sum()} vs spectrogram length {mel_lengths[idx]}"
+ assert dur.sum() == mel_lengths[idx], (
+ f" [!] total duration {dur.sum()} vs spectrogram length {mel_lengths[idx]}"
+ )
durations[idx, : text_lengths[idx]] = dur
# set stop targets wrt reduction factor
@@ -285,12 +284,12 @@ def get_sampler(self, config: Coqpit, dataset: TTSDataset, num_gpus=1):
def get_data_loader(
self,
config: Coqpit,
- assets: Dict,
+ assets: dict,
is_eval: bool,
- samples: Union[List[Dict], List[List]],
+ samples: list[dict] | list[list],
verbose: bool,
num_gpus: int,
- rank: int = None,
+ rank: int | None = None,
) -> "DataLoader":
if is_eval and not config.run_eval:
loader = None
@@ -366,7 +365,7 @@ def get_data_loader(
def _get_test_aux_input(
self,
- ) -> Dict:
+ ) -> dict:
d_vector = None
if self.config.use_d_vector_file:
d_vector = [self.speaker_manager.embeddings[name]["embedding"] for name in self.speaker_manager.embeddings]
@@ -383,7 +382,7 @@ def _get_test_aux_input(
}
return aux_inputs
- def test_run(self, assets: Dict) -> Tuple[Dict, Dict]:
+ def test_run(self, assets: dict) -> tuple[dict, dict]:
"""Generic test run for `tts` models used by `Trainer`.
You can override this for a different behaviour.
@@ -414,13 +413,11 @@ def test_run(self, assets: Dict) -> Tuple[Dict, Dict]:
use_griffin_lim=True,
do_trim_silence=False,
)
- test_audios["{}-audio".format(idx)] = outputs_dict["wav"]
- test_figures["{}-prediction".format(idx)] = plot_spectrogram(
+ test_audios[f"{idx}-audio"] = outputs_dict["wav"]
+ test_figures[f"{idx}-prediction"] = plot_spectrogram(
outputs_dict["outputs"]["model_outputs"], self.ap, output_fig=False
)
- test_figures["{}-alignment".format(idx)] = plot_alignment(
- outputs_dict["outputs"]["alignments"], output_fig=False
- )
+ test_figures[f"{idx}-alignment"] = plot_alignment(outputs_dict["outputs"]["alignments"], output_fig=False)
return test_figures, test_audios
def on_init_start(self, trainer):
diff --git a/TTS/tts/models/delightful_tts.py b/TTS/tts/models/delightful_tts.py
index bee008e26f..2d59db74c0 100644
--- a/TTS/tts/models/delightful_tts.py
+++ b/TTS/tts/models/delightful_tts.py
@@ -3,7 +3,6 @@
from dataclasses import dataclass, field
from itertools import chain
from pathlib import Path
-from typing import Dict, List, Optional, Tuple, Union
import numpy as np
import torch
@@ -65,7 +64,7 @@ class ForwardTTSE2eF0Dataset(F0Dataset):
def __init__(
self,
ap,
- samples: Union[List[List], List[Dict]],
+ samples: list[list] | list[dict],
cache_path: str = None,
precompute_num_workers=0,
normalize_f0=True,
@@ -275,15 +274,15 @@ def collate_fn(self, batch):
@dataclass
class VocoderConfig(Coqpit):
resblock_type_decoder: str = "1"
- resblock_kernel_sizes_decoder: List[int] = field(default_factory=lambda: [3, 7, 11])
- resblock_dilation_sizes_decoder: List[List[int]] = field(default_factory=lambda: [[1, 3, 5], [1, 3, 5], [1, 3, 5]])
- upsample_rates_decoder: List[int] = field(default_factory=lambda: [8, 8, 2, 2])
+ resblock_kernel_sizes_decoder: list[int] = field(default_factory=lambda: [3, 7, 11])
+ resblock_dilation_sizes_decoder: list[list[int]] = field(default_factory=lambda: [[1, 3, 5], [1, 3, 5], [1, 3, 5]])
+ upsample_rates_decoder: list[int] = field(default_factory=lambda: [8, 8, 2, 2])
upsample_initial_channel_decoder: int = 512
- upsample_kernel_sizes_decoder: List[int] = field(default_factory=lambda: [16, 16, 4, 4])
+ upsample_kernel_sizes_decoder: list[int] = field(default_factory=lambda: [16, 16, 4, 4])
use_spectral_norm_discriminator: bool = False
- upsampling_rates_discriminator: List[int] = field(default_factory=lambda: [4, 4, 4, 4])
- periods_discriminator: List[int] = field(default_factory=lambda: [2, 3, 5, 7, 11])
- pretrained_model_path: Optional[str] = None
+ upsampling_rates_discriminator: list[int] = field(default_factory=lambda: [4, 4, 4, 4])
+ periods_discriminator: list[int] = field(default_factory=lambda: [2, 3, 5, 7, 11])
+ pretrained_model_path: str | None = None
@dataclass
@@ -553,7 +552,7 @@ def forward(
attn_priors: torch.FloatTensor = None,
d_vectors: torch.FloatTensor = None,
speaker_idx: torch.LongTensor = None,
- ) -> Dict:
+ ) -> dict:
"""Model's forward pass.
Args:
@@ -832,9 +831,7 @@ def _log(self, batch, outputs, name_prefix="train"):
audios[f"{name_prefix}/vocoder_audio"] = sample_voice
return figures, audios
- def train_log(
- self, batch: dict, outputs: dict, logger: "Logger", assets: dict, steps: int
- ): # pylint: disable=no-self-use, unused-argument
+ def train_log(self, batch: dict, outputs: dict, logger: "Logger", assets: dict, steps: int): # pylint: disable=no-self-use, unused-argument
"""Create visualizations and waveform examples.
For example, here you can plot spectrograms and generate sample waveforms from these spectrograms to
@@ -1015,7 +1012,7 @@ def synthesize_with_gl(self, text: str, speaker_id, d_vector):
return return_dict
@torch.inference_mode()
- def test_run(self, assets) -> Tuple[Dict, Dict]:
+ def test_run(self, assets) -> tuple[dict, dict]:
"""Generic test run for `tts` models used by `Trainer`.
You can override this for a different behaviour.
@@ -1041,18 +1038,22 @@ def test_run(self, assets) -> Tuple[Dict, Dict]:
d_vector=aux_inputs["d_vector"],
)
# speaker_name = self.speaker_manager.speaker_names[aux_inputs["speaker_id"]]
- test_audios["{}-audio".format(idx)] = outputs["wav"].T
- test_audios["{}-audio_encoder".format(idx)] = outputs_gl["wav"].T
- test_figures["{}-alignment".format(idx)] = plot_alignment(outputs["alignments"], output_fig=False)
+ test_audios[f"{idx}-audio"] = outputs["wav"].T
+ test_audios[f"{idx}-audio_encoder"] = outputs_gl["wav"].T
+ test_figures[f"{idx}-alignment"] = plot_alignment(outputs["alignments"], output_fig=False)
return {"figures": test_figures, "audios": test_audios}
def test_log(
- self, outputs: dict, logger: "Logger", assets: dict, steps: int # pylint: disable=unused-argument
+ self,
+ outputs: dict,
+ logger: "Logger",
+ assets: dict,
+ steps: int, # pylint: disable=unused-argument
) -> None:
logger.test_audios(steps, outputs["audios"], self.config.audio.sample_rate)
logger.test_figures(steps, outputs["figures"])
- def format_batch(self, batch: Dict) -> Dict:
+ def format_batch(self, batch: dict) -> dict:
"""Compute speaker, langugage IDs and d_vector for the batch if necessary."""
speaker_ids = None
d_vectors = None
@@ -1160,12 +1161,12 @@ def get_sampler(self, config: Coqpit, dataset: TTSDataset, num_gpus=1):
def get_data_loader(
self,
config: Coqpit,
- assets: Dict,
+ assets: dict,
is_eval: bool,
- samples: Union[List[Dict], List[List]],
+ samples: list[dict] | list[list],
verbose: bool,
num_gpus: int,
- rank: int = None,
+ rank: int | None = None,
) -> "DataLoader":
if is_eval and not config.run_eval:
loader = None
@@ -1217,7 +1218,7 @@ def get_data_loader(
def get_criterion(self):
return [VitsDiscriminatorLoss(self.config), DelightfulTTSLoss(self.config)]
- def get_optimizer(self) -> List:
+ def get_optimizer(self) -> list:
"""Initiate and return the GAN optimizers based on the config parameters.
It returns 2 optimizers in a list. First one is for the generator and the second one is for the discriminator.
Returns:
@@ -1232,7 +1233,7 @@ def get_optimizer(self) -> List:
)
return [optimizer_disc, optimizer_gen]
- def get_lr(self) -> List:
+ def get_lr(self) -> list:
"""Set the initial learning rates for each optimizer.
Returns:
@@ -1240,7 +1241,7 @@ def get_lr(self) -> List:
"""
return [self.config.lr_disc, self.config.lr_gen]
- def get_scheduler(self, optimizer) -> List:
+ def get_scheduler(self, optimizer) -> list:
"""Set the schedulers for each optimizer.
Args:
@@ -1259,9 +1260,7 @@ def on_epoch_end(self, trainer): # pylint: disable=unused-argument
self.energy_scaler.eval()
@staticmethod
- def init_from_config(
- config: "DelightfulTTSConfig", samples: Union[List[List], List[Dict]] = None
- ): # pylint: disable=unused-argument
+ def init_from_config(config: "DelightfulTTSConfig", samples: list[list] | list[dict] = None): # pylint: disable=unused-argument
"""Initiate model from config
Args:
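The annotation changes in delightful_tts.py follow the cleanup applied across the whole diff: builtin generics (PEP 585) replace `typing.List`/`Dict`/`Tuple`, and `X | Y` unions (PEP 604) replace `Union`/`Optional`. This spelling generally assumes Python 3.10 or newer, at least for annotations that are evaluated at runtime such as dataclass fields. A minimal sketch of the pattern on a hypothetical config class that mirrors the fields above:

from dataclasses import dataclass, field

@dataclass
class VocoderConfigSketch:  # hypothetical stand-in, not part of this diff
    # was: resblock_kernel_sizes_decoder: List[int] = field(default_factory=lambda: [3, 7, 11])
    resblock_kernel_sizes_decoder: list[int] = field(default_factory=lambda: [3, 7, 11])
    # was: pretrained_model_path: Optional[str] = None
    pretrained_model_path: str | None = None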
diff --git a/TTS/tts/models/forward_tts.py b/TTS/tts/models/forward_tts.py
index 03166fa8c0..497ac3f63a 100644
--- a/TTS/tts/models/forward_tts.py
+++ b/TTS/tts/models/forward_tts.py
@@ -1,6 +1,5 @@
import logging
from dataclasses import dataclass, field
-from typing import Dict, List, Tuple, Union
import torch
from coqpit import Coqpit
@@ -333,7 +332,7 @@ def format_durations(self, o_dr_log, x_mask):
def _forward_encoder(
self, x: torch.LongTensor, x_mask: torch.FloatTensor, g: torch.FloatTensor = None
- ) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
+ ) -> tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
"""Encoding forward pass.
1. Embed speaker IDs if multi-speaker mode.
@@ -381,7 +380,7 @@ def _forward_decoder(
x_mask: torch.FloatTensor,
y_lengths: torch.IntTensor,
g: torch.FloatTensor,
- ) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
+ ) -> tuple[torch.FloatTensor, torch.FloatTensor]:
"""Decoding forward pass.
1. Compute the decoder output mask
@@ -415,7 +414,7 @@ def _forward_pitch_predictor(
x_mask: torch.IntTensor,
pitch: torch.FloatTensor = None,
dr: torch.IntTensor = None,
- ) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
+ ) -> tuple[torch.FloatTensor, torch.FloatTensor]:
"""Pitch predictor forward pass.
1. Predict pitch from encoder outputs.
@@ -451,7 +450,7 @@ def _forward_energy_predictor(
x_mask: torch.IntTensor,
energy: torch.FloatTensor = None,
dr: torch.IntTensor = None,
- ) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
+ ) -> tuple[torch.FloatTensor, torch.FloatTensor]:
"""Energy predictor forward pass.
1. Predict energy from encoder outputs.
@@ -483,7 +482,7 @@ def _forward_energy_predictor(
def _forward_aligner(
self, x: torch.FloatTensor, y: torch.FloatTensor, x_mask: torch.IntTensor, y_mask: torch.IntTensor
- ) -> Tuple[torch.IntTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
+ ) -> tuple[torch.IntTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
"""Aligner forward pass.
1. Compute a mask to apply to the attention map.
@@ -522,7 +521,7 @@ def _forward_aligner(
alignment_soft = alignment_soft.squeeze(1).transpose(1, 2)
return o_alignment_dur, alignment_soft, alignment_logprob, alignment_mas
- def _set_speaker_input(self, aux_input: Dict):
+ def _set_speaker_input(self, aux_input: dict):
d_vectors = aux_input.get("d_vectors", None)
speaker_ids = aux_input.get("speaker_ids", None)
@@ -544,8 +543,8 @@ def forward(
dr: torch.IntTensor = None,
pitch: torch.FloatTensor = None,
energy: torch.FloatTensor = None,
- aux_input: Dict = {"d_vectors": None, "speaker_ids": None}, # pylint: disable=unused-argument
- ) -> Dict:
+ aux_input: dict = {"d_vectors": None, "speaker_ids": None}, # pylint: disable=unused-argument
+ ) -> dict:
"""Model's forward pass.
Args:
@@ -771,9 +770,7 @@ def _create_logs(self, batch, outputs, ap):
train_audio = ap.inv_melspectrogram(pred_spec.T)
return figures, {"audio": train_audio}
- def train_log(
- self, batch: dict, outputs: dict, logger: "Logger", assets: dict, steps: int
- ) -> None: # pylint: disable=no-self-use
+ def train_log(self, batch: dict, outputs: dict, logger: "Logger", assets: dict, steps: int) -> None: # pylint: disable=no-self-use
figures, audios = self._create_logs(batch, outputs, self.ap)
logger.train_figures(steps, figures)
logger.train_audios(steps, audios, self.ap.sample_rate)
@@ -786,9 +783,7 @@ def eval_log(self, batch: dict, outputs: dict, logger: "Logger", assets: dict, s
logger.eval_figures(steps, figures)
logger.eval_audios(steps, audios, self.ap.sample_rate)
- def load_checkpoint(
- self, config, checkpoint_path, eval=False, cache=False
- ): # pylint: disable=unused-argument, redefined-builtin
+ def load_checkpoint(self, config, checkpoint_path, eval=False, cache=False): # pylint: disable=unused-argument, redefined-builtin
state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"), cache=cache)
self.load_state_dict(state["model"])
if eval:
@@ -805,7 +800,7 @@ def on_train_step_start(self, trainer):
self.binary_loss_weight = min(trainer.epochs_done / self.config.binary_loss_warmup_epochs, 1.0) * 1.0
@staticmethod
- def init_from_config(config: "ForwardTTSConfig", samples: Union[List[List], List[Dict]] = None):
+ def init_from_config(config: "ForwardTTSConfig", samples: list[list] | list[dict] = None):
"""Initiate model from config
Args:
diff --git a/TTS/tts/models/glow_tts.py b/TTS/tts/models/glow_tts.py
index aaf5190ada..5d03b53dc6 100644
--- a/TTS/tts/models/glow_tts.py
+++ b/TTS/tts/models/glow_tts.py
@@ -1,6 +1,5 @@
import logging
import math
-from typing import Dict, List, Tuple, Union
import torch
from coqpit import Coqpit
@@ -125,9 +124,9 @@ def init_multispeaker(self, config: Coqpit):
config.d_vector_dim if "d_vector_dim" in config and config.d_vector_dim is not None else 512
)
if self.speaker_manager is not None:
- assert (
- config.d_vector_dim == self.speaker_manager.embedding_dim
- ), " [!] d-vector dimension mismatch b/w config and speaker manager."
+ assert config.d_vector_dim == self.speaker_manager.embedding_dim, (
+ " [!] d-vector dimension mismatch b/w config and speaker manager."
+ )
# init speaker embedding layer
if config.use_speaker_embedding and not config.use_d_vector_file:
logger.info("Init speaker_embedding layer.")
@@ -162,7 +161,7 @@ def lock_act_norm_layers(self):
if getattr(f, "set_ddi", False):
f.set_ddi(False)
- def _set_speaker_input(self, aux_input: Dict):
+ def _set_speaker_input(self, aux_input: dict):
if aux_input is None:
d_vectors = None
speaker_ids = None
@@ -179,7 +178,7 @@ def _set_speaker_input(self, aux_input: Dict):
g = speaker_ids if speaker_ids is not None else d_vectors
return g
- def _speaker_embedding(self, aux_input: Dict) -> Union[torch.tensor, None]:
+ def _speaker_embedding(self, aux_input: dict) -> torch.Tensor | None:
g = self._set_speaker_input(aux_input)
# speaker embedding
if g is not None:
@@ -193,9 +192,7 @@ def _speaker_embedding(self, aux_input: Dict) -> Union[torch.tensor, None]:
g = F.normalize(g).unsqueeze(-1) # [b, h, 1]
return g
- def forward(
- self, x, x_lengths, y, y_lengths=None, aux_input={"d_vectors": None, "speaker_ids": None}
- ): # pylint: disable=dangerous-default-value
+ def forward(self, x, x_lengths, y, y_lengths=None, aux_input={"d_vectors": None, "speaker_ids": None}): # pylint: disable=dangerous-default-value
"""
Args:
x (torch.Tensor):
@@ -319,9 +316,7 @@ def inference_with_MAS(
return outputs
@torch.inference_mode()
- def decoder_inference(
- self, y, y_lengths=None, aux_input={"d_vectors": None, "speaker_ids": None}
- ): # pylint: disable=dangerous-default-value
+ def decoder_inference(self, y, y_lengths=None, aux_input={"d_vectors": None, "speaker_ids": None}): # pylint: disable=dangerous-default-value
"""
Shapes:
- y: :math:`[B, T, C]`
@@ -342,9 +337,7 @@ def decoder_inference(
return outputs
@torch.inference_mode()
- def inference(
- self, x, aux_input={"x_lengths": None, "d_vectors": None, "speaker_ids": None}
- ): # pylint: disable=dangerous-default-value
+ def inference(self, x, aux_input={"x_lengths": None, "d_vectors": None, "speaker_ids": None}): # pylint: disable=dangerous-default-value
x_lengths = aux_input["x_lengths"]
g = self._speaker_embedding(aux_input)
# embedding pass
@@ -457,9 +450,7 @@ def _create_logs(self, batch, outputs, ap):
train_audio = ap.inv_melspectrogram(pred_spec.T)
return figures, {"audio": train_audio}
- def train_log(
- self, batch: dict, outputs: dict, logger: "Logger", assets: dict, steps: int
- ) -> None: # pylint: disable=no-self-use
+ def train_log(self, batch: dict, outputs: dict, logger: "Logger", assets: dict, steps: int) -> None: # pylint: disable=no-self-use
figures, audios = self._create_logs(batch, outputs, self.ap)
logger.train_figures(steps, figures)
logger.train_audios(steps, audios, self.ap.sample_rate)
@@ -474,7 +465,7 @@ def eval_log(self, batch: dict, outputs: dict, logger: "Logger", assets: dict, s
logger.eval_audios(steps, audios, self.ap.sample_rate)
@torch.inference_mode()
- def test_run(self, assets: Dict) -> Tuple[Dict, Dict]:
+ def test_run(self, assets: dict) -> tuple[dict, dict]:
"""Generic test run for `tts` models used by `Trainer`.
You can override this for a different behaviour.
@@ -503,11 +494,11 @@ def test_run(self, assets: Dict) -> Tuple[Dict, Dict]:
do_trim_silence=False,
)
- test_audios["{}-audio".format(idx)] = outputs["wav"]
- test_figures["{}-prediction".format(idx)] = plot_spectrogram(
+ test_audios[f"{idx}-audio"] = outputs["wav"]
+ test_figures[f"{idx}-prediction"] = plot_spectrogram(
outputs["outputs"]["model_outputs"], self.ap, output_fig=False
)
- test_figures["{}-alignment".format(idx)] = plot_alignment(outputs["alignments"], output_fig=False)
+ test_figures[f"{idx}-alignment"] = plot_alignment(outputs["alignments"], output_fig=False)
return test_figures, test_audios
def preprocess(self, y, y_lengths, y_max_length, attn=None):
@@ -522,9 +513,7 @@ def preprocess(self, y, y_lengths, y_max_length, attn=None):
def store_inverse(self):
self.decoder.store_inverse()
- def load_checkpoint(
- self, config, checkpoint_path, eval=False
- ): # pylint: disable=unused-argument, redefined-builtin
+ def load_checkpoint(self, config, checkpoint_path, eval=False): # pylint: disable=unused-argument, redefined-builtin
state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"))
self.load_state_dict(state["model"])
if eval:
@@ -543,7 +532,7 @@ def on_train_step_start(self, trainer):
self.run_data_dep_init = trainer.total_steps_done < self.data_dep_init_steps
@staticmethod
- def init_from_config(config: "GlowTTSConfig", samples: Union[List[List], List[Dict]] = None):
+ def init_from_config(config: "GlowTTSConfig", samples: list[list] | list[dict] = None):
"""Initiate model from config
Args:
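Two things change in glow_tts.py besides the typing cleanup. The long asserts are re-wrapped so the condition stays on one line and the message is parenthesized, which is how current formatters lay out asserts that do not fit on one line; runtime behaviour is identical. The `_speaker_embedding` return annotation also swaps `torch.tensor` (a factory function) for `torch.Tensor` (the class), which is the correct type to annotate with. A sketch of the assert re-wrap, using placeholder values:

d_vector_dim = embedding_dim = 512
# old layout: parentheses around the condition
assert (
    d_vector_dim == embedding_dim
), " [!] d-vector dimension mismatch b/w config and speaker manager."
# new layout: parentheses around the message
assert d_vector_dim == embedding_dim, (
    " [!] d-vector dimension mismatch b/w config and speaker manager."
)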
diff --git a/TTS/tts/models/neuralhmm_tts.py b/TTS/tts/models/neuralhmm_tts.py
index b9a23000a0..2cbf425884 100644
--- a/TTS/tts/models/neuralhmm_tts.py
+++ b/TTS/tts/models/neuralhmm_tts.py
@@ -1,6 +1,5 @@
import logging
import os
-from typing import Dict, List, Union
import torch
from coqpit import Coqpit
@@ -102,7 +101,7 @@ def __init__(
self.register_buffer("mean", torch.tensor(0))
self.register_buffer("std", torch.tensor(1))
- def update_mean_std(self, statistics_dict: Dict):
+ def update_mean_std(self, statistics_dict: dict):
self.mean.data = torch.tensor(statistics_dict["mean"])
self.std.data = torch.tensor(statistics_dict["std"])
@@ -174,10 +173,10 @@ def train_step(self, batch: dict, criterion: nn.Module):
loss_dict.update(self._training_stats(batch))
return outputs, loss_dict
- def eval_step(self, batch: Dict, criterion: nn.Module):
+ def eval_step(self, batch: dict, criterion: nn.Module):
return self.train_step(batch, criterion)
- def _format_aux_input(self, aux_input: Dict, default_input_dict):
+ def _format_aux_input(self, aux_input: dict, default_input_dict):
"""Set missing fields to their default value.
Args:
@@ -239,7 +238,7 @@ def get_criterion():
return NLLLoss()
@staticmethod
- def init_from_config(config: "NeuralhmmTTSConfig", samples: Union[List[List], List[Dict]] = None):
+ def init_from_config(config: "NeuralhmmTTSConfig", samples: list[list] | list[dict] = None):
"""Initiate model from config
Args:
@@ -346,17 +345,13 @@ def _create_logs(self, batch, outputs, ap): # pylint: disable=no-self-use, unus
audio = ap.inv_melspectrogram(inference_output["model_outputs"][0].T.cpu().numpy())
return figures, {"audios": audio}
- def train_log(
- self, batch: dict, outputs: dict, logger: "Logger", assets: dict, steps: int
- ): # pylint: disable=unused-argument
+ def train_log(self, batch: dict, outputs: dict, logger: "Logger", assets: dict, steps: int): # pylint: disable=unused-argument
"""Log training progress."""
figures, audios = self._create_logs(batch, outputs, self.ap)
logger.train_figures(steps, figures)
logger.train_audios(steps, audios, self.ap.sample_rate)
- def eval_log(
- self, batch: Dict, outputs: Dict, logger: "Logger", assets: Dict, steps: int
- ): # pylint: disable=unused-argument
+ def eval_log(self, batch: dict, outputs: dict, logger: "Logger", assets: dict, steps: int): # pylint: disable=unused-argument
"""Compute and log evaluation metrics."""
# Plot model parameters histograms
if isinstance(logger, TensorboardLogger):
@@ -370,7 +365,11 @@ def eval_log(
logger.eval_audios(steps, audios, self.ap.sample_rate)
def test_log(
- self, outputs: dict, logger: "Logger", assets: dict, steps: int # pylint: disable=unused-argument
+ self,
+ outputs: dict,
+ logger: "Logger",
+ assets: dict,
+ steps: int, # pylint: disable=unused-argument
) -> None:
logger.test_audios(steps, outputs[1], self.ap.sample_rate)
logger.test_figures(steps, outputs[0])
diff --git a/TTS/tts/models/overflow.py b/TTS/tts/models/overflow.py
index 10157e43a4..aad2e1f553 100644
--- a/TTS/tts/models/overflow.py
+++ b/TTS/tts/models/overflow.py
@@ -1,6 +1,5 @@
import logging
import os
-from typing import Dict, List, Union
import torch
from coqpit import Coqpit
@@ -116,7 +115,7 @@ def __init__(
self.register_buffer("mean", torch.tensor(0))
self.register_buffer("std", torch.tensor(1))
- def update_mean_std(self, statistics_dict: Dict):
+ def update_mean_std(self, statistics_dict: dict):
self.mean.data = torch.tensor(statistics_dict["mean"])
self.std.data = torch.tensor(statistics_dict["std"])
@@ -188,10 +187,10 @@ def train_step(self, batch: dict, criterion: nn.Module):
loss_dict.update(self._training_stats(batch))
return outputs, loss_dict
- def eval_step(self, batch: Dict, criterion: nn.Module):
+ def eval_step(self, batch: dict, criterion: nn.Module):
return self.train_step(batch, criterion)
- def _format_aux_input(self, aux_input: Dict, default_input_dict):
+ def _format_aux_input(self, aux_input: dict, default_input_dict):
"""Set missing fields to their default value.
Args:
@@ -255,7 +254,7 @@ def get_criterion():
return NLLLoss()
@staticmethod
- def init_from_config(config: "OverFlowConfig", samples: Union[List[List], List[Dict]] = None):
+ def init_from_config(config: "OverFlowConfig", samples: list[list] | list[dict] = None):
"""Initiate model from config
Args:
@@ -363,17 +362,13 @@ def _create_logs(self, batch, outputs, ap): # pylint: disable=no-self-use, unus
audio = ap.inv_melspectrogram(inference_output["model_outputs"][0].T.cpu().numpy())
return figures, {"audios": audio}
- def train_log(
- self, batch: dict, outputs: dict, logger: "Logger", assets: dict, steps: int
- ): # pylint: disable=unused-argument
+ def train_log(self, batch: dict, outputs: dict, logger: "Logger", assets: dict, steps: int): # pylint: disable=unused-argument
"""Log training progress."""
figures, audios = self._create_logs(batch, outputs, self.ap)
logger.train_figures(steps, figures)
logger.train_audios(steps, audios, self.ap.sample_rate)
- def eval_log(
- self, batch: Dict, outputs: Dict, logger: "Logger", assets: Dict, steps: int
- ): # pylint: disable=unused-argument
+ def eval_log(self, batch: dict, outputs: dict, logger: "Logger", assets: dict, steps: int): # pylint: disable=unused-argument
"""Compute and log evaluation metrics."""
# Plot model parameters histograms
if isinstance(logger, TensorboardLogger):
@@ -387,7 +382,11 @@ def eval_log(
logger.eval_audios(steps, audios, self.ap.sample_rate)
def test_log(
- self, outputs: dict, logger: "Logger", assets: dict, steps: int # pylint: disable=unused-argument
+ self,
+ outputs: dict,
+ logger: "Logger",
+ assets: dict,
+ steps: int, # pylint: disable=unused-argument
) -> None:
logger.test_audios(steps, outputs[1], self.ap.sample_rate)
logger.test_figures(steps, outputs[0])
diff --git a/TTS/tts/models/tacotron.py b/TTS/tts/models/tacotron.py
index da85823f3f..59173691f7 100644
--- a/TTS/tts/models/tacotron.py
+++ b/TTS/tts/models/tacotron.py
@@ -1,7 +1,3 @@
-# coding: utf-8
-
-from typing import Dict, List, Tuple, Union
-
import torch
from torch import nn
from trainer.trainer_utils import get_optimizer, get_scheduler
@@ -280,7 +276,7 @@ def before_backward_pass(self, loss_dict, optimizer) -> None:
loss_dict["capacitron_vae_beta_loss"].backward()
optimizer.first_step()
- def train_step(self, batch: Dict, criterion: torch.nn.Module) -> Tuple[Dict, Dict]:
+ def train_step(self, batch: dict, criterion: torch.nn.Module) -> tuple[dict, dict]:
"""Perform a single training step by fetching the right set of samples from the batch.
Args:
@@ -332,7 +328,7 @@ def train_step(self, batch: Dict, criterion: torch.nn.Module) -> Tuple[Dict, Dic
loss_dict["align_error"] = align_error
return outputs, loss_dict
- def get_optimizer(self) -> List:
+ def get_optimizer(self) -> list:
if self.use_capacitron_vae:
return CapacitronOptimizer(self.config, self.named_parameters())
return get_optimizer(self.config.optimizer, self.config.optimizer_params, self.config.lr, self)
@@ -380,9 +376,7 @@ def _create_logs(self, batch, outputs, ap):
audio = ap.inv_spectrogram(pred_linear_spec.T)
return figures, {"audio": audio}
- def train_log(
- self, batch: dict, outputs: dict, logger: "Logger", assets: dict, steps: int
- ) -> None: # pylint: disable=no-self-use
+ def train_log(self, batch: dict, outputs: dict, logger: "Logger", assets: dict, steps: int) -> None: # pylint: disable=no-self-use
figures, audios = self._create_logs(batch, outputs, self.ap)
logger.train_figures(steps, figures)
logger.train_audios(steps, audios, self.ap.sample_rate)
@@ -396,7 +390,7 @@ def eval_log(self, batch: dict, outputs: dict, logger: "Logger", assets: dict, s
logger.eval_audios(steps, audios, self.ap.sample_rate)
@staticmethod
- def init_from_config(config: "TacotronConfig", samples: Union[List[List], List[Dict]] = None):
+ def init_from_config(config: "TacotronConfig", samples: list[list] | list[dict] = None):
"""Initiate model from config
Args:
diff --git a/TTS/tts/models/tacotron2.py b/TTS/tts/models/tacotron2.py
index e2edd4bb5c..e924d82d42 100644
--- a/TTS/tts/models/tacotron2.py
+++ b/TTS/tts/models/tacotron2.py
@@ -1,7 +1,3 @@
-# coding: utf-8
-
-from typing import Dict, List, Union
-
import torch
from torch import nn
from trainer.trainer_utils import get_optimizer, get_scheduler
@@ -309,7 +305,7 @@ def before_backward_pass(self, loss_dict, optimizer) -> None:
loss_dict["capacitron_vae_beta_loss"].backward()
optimizer.first_step()
- def train_step(self, batch: Dict, criterion: torch.nn.Module):
+ def train_step(self, batch: dict, criterion: torch.nn.Module):
"""A single training step. Forward pass and loss computation.
Args:
@@ -360,7 +356,7 @@ def train_step(self, batch: Dict, criterion: torch.nn.Module):
loss_dict["align_error"] = align_error
return outputs, loss_dict
- def get_optimizer(self) -> List:
+ def get_optimizer(self) -> list:
if self.use_capacitron_vae:
return CapacitronOptimizer(self.config, self.named_parameters())
return get_optimizer(self.config.optimizer, self.config.optimizer_params, self.config.lr, self)
@@ -403,9 +399,7 @@ def _create_logs(self, batch, outputs, ap):
audio = ap.inv_melspectrogram(pred_spec.T)
return figures, {"audio": audio}
- def train_log(
- self, batch: dict, outputs: dict, logger: "Logger", assets: dict, steps: int
- ) -> None: # pylint: disable=no-self-use
+ def train_log(self, batch: dict, outputs: dict, logger: "Logger", assets: dict, steps: int) -> None: # pylint: disable=no-self-use
"""Log training progress."""
figures, audios = self._create_logs(batch, outputs, self.ap)
logger.train_figures(steps, figures)
@@ -420,7 +414,7 @@ def eval_log(self, batch: dict, outputs: dict, logger: "Logger", assets: dict, s
logger.eval_audios(steps, audios, self.ap.sample_rate)
@staticmethod
- def init_from_config(config: "Tacotron2Config", samples: Union[List[List], List[Dict]] = None):
+ def init_from_config(config: "Tacotron2Config", samples: list[list] | list[dict] = None):
"""Initiate model from config
Args:
diff --git a/TTS/tts/models/tortoise.py b/TTS/tts/models/tortoise.py
index 738e9dd9b3..a42d577676 100644
--- a/TTS/tts/models/tortoise.py
+++ b/TTS/tts/models/tortoise.py
@@ -342,7 +342,6 @@ def __init__(self, config: Coqpit):
else self.args.autoregressive_batch_size
)
self.enable_redaction = self.args.enable_redaction
- self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if self.enable_redaction:
self.aligner = Wav2VecAlignment()
@@ -685,9 +684,9 @@ def inference(
text_tokens = torch.IntTensor(self.tokenizer.encode(text)).unsqueeze(0).to(self.device)
text_tokens = F.pad(text_tokens, (0, 1)) # This may not be necessary.
- assert (
- text_tokens.shape[-1] < 400
- ), "Too much text provided. Break the text up into separate segments and re-try inference."
+ assert text_tokens.shape[-1] < 400, (
+ "Too much text provided. Break the text up into separate segments and re-try inference."
+ )
if voice_samples is not None:
(
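Dropping the explicit `self.device` assignment in Tortoise implies the device is resolved elsewhere, most likely from the module's own parameters or a base-class property; that mechanism is not visible in this hunk. A hedged sketch of the usual pattern such a property follows:

import torch
from torch import nn

class DeviceAwareModule(nn.Module):
    """Illustrative only: derive the device from the parameters instead of caching it."""

    def __init__(self) -> None:
        super().__init__()
        self.proj = nn.Linear(4, 4)

    @property
    def device(self) -> torch.device:
        return next(self.parameters()).device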
diff --git a/TTS/tts/models/vits.py b/TTS/tts/models/vits.py
index 135b8e5016..b542030f13 100644
--- a/TTS/tts/models/vits.py
+++ b/TTS/tts/models/vits.py
@@ -4,7 +4,7 @@
from dataclasses import dataclass, field, replace
from itertools import chain
from pathlib import Path
-from typing import Any, Dict, List, Tuple, Union
+from typing import Any
import numpy as np
import torch
@@ -401,12 +401,12 @@ class VitsArgs(Coqpit):
dilation_rate_flow: int = 1
num_layers_flow: int = 4
resblock_type_decoder: str = "1"
- resblock_kernel_sizes_decoder: List[int] = field(default_factory=lambda: [3, 7, 11])
- resblock_dilation_sizes_decoder: List[List[int]] = field(default_factory=lambda: [[1, 3, 5], [1, 3, 5], [1, 3, 5]])
- upsample_rates_decoder: List[int] = field(default_factory=lambda: [8, 8, 2, 2])
+ resblock_kernel_sizes_decoder: list[int] = field(default_factory=lambda: [3, 7, 11])
+ resblock_dilation_sizes_decoder: list[list[int]] = field(default_factory=lambda: [[1, 3, 5], [1, 3, 5], [1, 3, 5]])
+ upsample_rates_decoder: list[int] = field(default_factory=lambda: [8, 8, 2, 2])
upsample_initial_channel_decoder: int = 512
- upsample_kernel_sizes_decoder: List[int] = field(default_factory=lambda: [16, 16, 4, 4])
- periods_multi_period_discriminator: List[int] = field(default_factory=lambda: [2, 3, 5, 7, 11])
+ upsample_kernel_sizes_decoder: list[int] = field(default_factory=lambda: [16, 16, 4, 4])
+ periods_multi_period_discriminator: list[int] = field(default_factory=lambda: [2, 3, 5, 7, 11])
use_sdp: bool = True
noise_scale: float = 1.0
inference_noise_scale: float = 0.667
@@ -419,7 +419,7 @@ class VitsArgs(Coqpit):
use_speaker_embedding: bool = False
num_speakers: int = 0
speakers_file: str = None
- d_vector_file: List[str] = None
+ d_vector_file: list[str] = None
speaker_embedding_channels: int = 256
use_d_vector_file: bool = False
d_vector_dim: int = 0
@@ -680,7 +680,7 @@ def on_init_end(self, trainer): # pylint: disable=W0613
raise RuntimeError(" [!] The weights of Text Encoder was not reinit check it !")
logger.info("Text Encoder was reinit.")
- def get_aux_input(self, aux_input: Dict):
+ def get_aux_input(self, aux_input: dict):
sid, g, lid, _ = self._set_cond_input(aux_input)
return {"speaker_ids": sid, "style_wav": None, "d_vectors": g, "language_ids": lid}
@@ -710,7 +710,7 @@ def _freeze_layers(self):
param.requires_grad = False
@staticmethod
- def _set_cond_input(aux_input: Dict):
+ def _set_cond_input(aux_input: dict):
"""Set the speaker conditioning input based on the multi-speaker mode."""
sid, g, lid, durations = None, None, None, None
if "speaker_ids" in aux_input and aux_input["speaker_ids"] is not None:
@@ -732,7 +732,7 @@ def _set_cond_input(aux_input: Dict):
return sid, g, lid, durations
- def _set_speaker_input(self, aux_input: Dict):
+ def _set_speaker_input(self, aux_input: dict):
d_vectors = aux_input.get("d_vectors", None)
speaker_ids = aux_input.get("speaker_ids", None)
@@ -805,7 +805,7 @@ def forward( # pylint: disable=dangerous-default-value
y_lengths: torch.tensor,
waveform: torch.tensor,
aux_input={"d_vectors": None, "speaker_ids": None, "language_ids": None},
- ) -> Dict:
+ ) -> dict:
"""Forward pass of the model.
Args:
@@ -1052,8 +1052,8 @@ def voice_conversion(self, y, y_lengths, speaker_cond_src, speaker_cond_tgt):
assert self.num_speakers > 0, "num_speakers have to be larger than 0."
# speaker embedding
if self.args.use_speaker_embedding and not self.args.use_d_vector_file:
- g_src = self.emb_g(torch.from_numpy((np.array(speaker_cond_src))).unsqueeze(0)).unsqueeze(-1)
- g_tgt = self.emb_g(torch.from_numpy((np.array(speaker_cond_tgt))).unsqueeze(0)).unsqueeze(-1)
+ g_src = self.emb_g(torch.from_numpy(np.array(speaker_cond_src)).unsqueeze(0)).unsqueeze(-1)
+ g_tgt = self.emb_g(torch.from_numpy(np.array(speaker_cond_tgt)).unsqueeze(0)).unsqueeze(-1)
elif not self.args.use_speaker_embedding and self.args.use_d_vector_file:
g_src = F.normalize(speaker_cond_src).unsqueeze(-1)
g_tgt = F.normalize(speaker_cond_tgt).unsqueeze(-1)
@@ -1066,7 +1066,7 @@ def voice_conversion(self, y, y_lengths, speaker_cond_src, speaker_cond_tgt):
o_hat = self.waveform_decoder(z_hat * y_mask, g=g_tgt)
return o_hat, y_mask, (z, z_p, z_hat)
- def train_step(self, batch: dict, criterion: nn.Module, optimizer_idx: int) -> Tuple[Dict, Dict]:
+ def train_step(self, batch: dict, criterion: nn.Module, optimizer_idx: int) -> tuple[dict, dict]:
"""Perform a single training step. Run the model forward pass and compute losses.
Args:
@@ -1186,9 +1186,7 @@ def _log(self, ap, batch, outputs, name_prefix="train"): # pylint: disable=unus
)
return figures, audios
- def train_log(
- self, batch: dict, outputs: dict, logger: "Logger", assets: dict, steps: int
- ): # pylint: disable=no-self-use
+ def train_log(self, batch: dict, outputs: dict, logger: "Logger", assets: dict, steps: int): # pylint: disable=no-self-use
"""Create visualizations and waveform examples.
For example, here you can plot spectrograms and generate sample waveforms from these spectrograms to
@@ -1264,7 +1262,7 @@ def get_aux_input_from_test_sentences(self, sentence_info):
}
@torch.inference_mode()
- def test_run(self, assets) -> Tuple[Dict, Dict]:
+ def test_run(self, assets) -> tuple[dict, dict]:
"""Generic test run for `tts` models used by `Trainer`.
You can override this for a different behaviour.
@@ -1290,17 +1288,21 @@ def test_run(self, assets) -> Tuple[Dict, Dict]:
use_griffin_lim=True,
do_trim_silence=False,
).values()
- test_audios["{}-audio".format(idx)] = wav
- test_figures["{}-alignment".format(idx)] = plot_alignment(alignment.permute(2, 1, 0), output_fig=False)
+ test_audios[f"{idx}-audio"] = wav
+ test_figures[f"{idx}-alignment"] = plot_alignment(alignment.permute(2, 1, 0), output_fig=False)
return {"figures": test_figures, "audios": test_audios}
def test_log(
- self, outputs: dict, logger: "Logger", assets: dict, steps: int # pylint: disable=unused-argument
+ self,
+ outputs: dict,
+ logger: "Logger",
+ assets: dict,
+ steps: int, # pylint: disable=unused-argument
) -> None:
logger.test_audios(steps, outputs["audios"], self.ap.sample_rate)
logger.test_figures(steps, outputs["figures"])
- def format_batch(self, batch: Dict) -> Dict:
+ def format_batch(self, batch: dict) -> dict:
"""Compute speaker, langugage IDs and d_vector for the batch if necessary."""
speaker_ids = None
language_ids = None
@@ -1364,9 +1366,9 @@ def format_batch_on_device(self, batch):
)
if self.args.encoder_sample_rate:
- assert batch["spec"].shape[2] == int(
- batch["mel"].shape[2] / self.interpolate_factor
- ), f"{batch['spec'].shape[2]}, {batch['mel'].shape[2]}"
+ assert batch["spec"].shape[2] == int(batch["mel"].shape[2] / self.interpolate_factor), (
+ f"{batch['spec'].shape[2]}, {batch['mel'].shape[2]}"
+ )
else:
assert batch["spec"].shape[2] == batch["mel"].shape[2], f"{batch['spec'].shape[2]}, {batch['mel'].shape[2]}"
@@ -1423,12 +1425,12 @@ def get_sampler(self, config: Coqpit, dataset: TTSDataset, num_gpus=1, is_eval=F
def get_data_loader(
self,
config: Coqpit,
- assets: Dict,
+ assets: dict,
is_eval: bool,
- samples: Union[List[Dict], List[List]],
+ samples: list[dict] | list[list],
verbose: bool,
num_gpus: int,
- rank: int = None,
+ rank: int | None = None,
) -> "DataLoader":
if is_eval and not config.run_eval:
loader = None
@@ -1487,7 +1489,7 @@ def get_data_loader(
)
return loader
- def get_optimizer(self) -> List:
+ def get_optimizer(self) -> list:
"""Initiate and return the GAN optimizers based on the config parameters.
It returns 2 optimizers in a list. First one is for the discriminator
@@ -1505,7 +1507,7 @@ def get_optimizer(self) -> List:
)
return [optimizer0, optimizer1]
- def get_lr(self) -> List:
+ def get_lr(self) -> list:
"""Set the initial learning rates for each optimizer.
Returns:
@@ -1513,7 +1515,7 @@ def get_lr(self) -> List:
"""
return [self.config.lr_disc, self.config.lr_gen]
- def get_scheduler(self, optimizer) -> List:
+ def get_scheduler(self, optimizer) -> list:
"""Set the schedulers for each optimizer.
Args:
@@ -1536,9 +1538,7 @@ def get_criterion(self):
return [VitsDiscriminatorLoss(self.config), VitsGeneratorLoss(self.config)]
- def load_checkpoint(
- self, config, checkpoint_path, eval=False, strict=True, cache=False
- ): # pylint: disable=unused-argument, redefined-builtin
+ def load_checkpoint(self, config, checkpoint_path, eval=False, strict=True, cache=False): # pylint: disable=unused-argument, redefined-builtin
"""Load the model checkpoint and setup for training or inference"""
state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"), cache=cache)
# compat band-aid for the pre-trained models to not use the encoder baked into the model
@@ -1565,9 +1565,7 @@ def load_checkpoint(
self.eval()
assert not self.training
- def load_fairseq_checkpoint(
- self, config, checkpoint_dir, eval=False, strict=True
- ): # pylint: disable=unused-argument, redefined-builtin
+ def load_fairseq_checkpoint(self, config, checkpoint_dir, eval=False, strict=True): # pylint: disable=unused-argument, redefined-builtin
"""Load VITS checkpoints released by fairseq here: https://github.com/facebookresearch/fairseq/tree/main/examples/mms
Performs some changes for compatibility.
@@ -1589,7 +1587,7 @@ def load_fairseq_checkpoint(
checkpoint_file = checkpoint_dir / "G_100000.pth"
vocab_file = checkpoint_dir / "vocab.txt"
# set config params
- with open(config_file, "r", encoding="utf-8") as f:
+ with open(config_file, encoding="utf-8") as f:
# Load the JSON data as a dictionary
config_org = json.load(f)
self.config.audio.sample_rate = config_org["data"]["sampling_rate"]
@@ -1613,7 +1611,7 @@ def load_fairseq_checkpoint(
assert not self.training
@staticmethod
- def init_from_config(config: "VitsConfig", samples: Union[List[List], List[Dict]] = None):
+ def init_from_config(config: "VitsConfig", samples: list[list] | list[dict] = None):
"""Initiate model from config
Args:
@@ -1626,15 +1624,15 @@ def init_from_config(config: "VitsConfig", samples: Union[List[List], List[Dict]
upsample_rate = torch.prod(torch.as_tensor(config.model_args.upsample_rates_decoder)).item()
if not config.model_args.encoder_sample_rate:
- assert (
- upsample_rate == config.audio.hop_length
- ), f" [!] Product of upsample rates must be equal to the hop length - {upsample_rate} vs {config.audio.hop_length}"
+ assert upsample_rate == config.audio.hop_length, (
+ f" [!] Product of upsample rates must be equal to the hop length - {upsample_rate} vs {config.audio.hop_length}"
+ )
else:
encoder_to_vocoder_upsampling_factor = config.audio.sample_rate / config.model_args.encoder_sample_rate
effective_hop_length = config.audio.hop_length * encoder_to_vocoder_upsampling_factor
- assert (
- upsample_rate == effective_hop_length
- ), f" [!] Product of upsample rates must be equal to the hop length - {upsample_rate} vs {effective_hop_length}"
+ assert upsample_rate == effective_hop_length, (
+ f" [!] Product of upsample rates must be equal to the hop length - {upsample_rate} vs {effective_hop_length}"
+ )
ap = AudioProcessor.init_from_config(config)
tokenizer, new_config = TTSTokenizer.init_from_config(config)
@@ -1825,7 +1823,7 @@ def to_config(self) -> "CharactersConfig":
class FairseqVocab(BaseVocabulary):
- def __init__(self, vocab: Union[str, os.PathLike[Any]]):
+ def __init__(self, vocab: str | os.PathLike[Any]):
super(FairseqVocab).__init__()
self.vocab = vocab
@@ -1835,7 +1833,7 @@ def vocab(self):
return self._vocab
@vocab.setter
- def vocab(self, vocab_file: Union[str, os.PathLike[Any]]):
+ def vocab(self, vocab_file: str | os.PathLike[Any]):
with open(vocab_file, encoding="utf-8") as f:
self._vocab = [x.replace("\n", "") for x in f.readlines()]
self.blank = self._vocab[0]
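Two small behaviour-preserving cleanups in vits.py are worth calling out: `open(config_file, "r", encoding="utf-8")` drops the redundant `"r"` because reading text is already the default mode, and `rank: int = None` becomes `rank: int | None = None` so the annotation matches the `None` default. For example:

# "r" is the default mode, so both spellings behave identically (file name only illustrative).
with open("vocab.txt", encoding="utf-8") as f:
    vocab = [line.rstrip("\n") for line in f]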
diff --git a/TTS/tts/models/xtts.py b/TTS/tts/models/xtts.py
index c0a50800f6..2df07a0435 100644
--- a/TTS/tts/models/xtts.py
+++ b/TTS/tts/models/xtts.py
@@ -2,7 +2,6 @@
import os
from dataclasses import dataclass
from pathlib import Path
-from typing import Optional
import librosa
import torch
@@ -380,9 +379,9 @@ def synthesize(self, text, config, speaker_wav, language, speaker_id=None, **kwa
as latents used at inference.
"""
- assert (
- "zh-cn" if language == "zh" else language in self.config.languages
- ), f" ❗ Language {language} is not supported. Supported languages are {self.config.languages}"
+ assert "zh-cn" if language == "zh" else language in self.config.languages, (
+ f" ❗ Language {language} is not supported. Supported languages are {self.config.languages}"
+ )
# Use generally found best tuning knobs for generation.
settings = {
"temperature": config.temperature,
@@ -520,9 +519,9 @@ def inference(
sent = sent.strip().lower()
text_tokens = torch.IntTensor(self.tokenizer.encode(sent, lang=language)).unsqueeze(0).to(self.device)
- assert (
- text_tokens.shape[-1] < self.args.gpt_max_text_tokens
- ), " ❗ XTTS can only generate text with a maximum of 400 tokens."
+ assert text_tokens.shape[-1] < self.args.gpt_max_text_tokens, (
+ " ❗ XTTS can only generate text with a maximum of 400 tokens."
+ )
with torch.no_grad():
gpt_codes = self.gpt.generate(
@@ -628,9 +627,9 @@ def inference_stream(
sent = sent.strip().lower()
text_tokens = torch.IntTensor(self.tokenizer.encode(sent, lang=language)).unsqueeze(0).to(self.device)
- assert (
- text_tokens.shape[-1] < self.args.gpt_max_text_tokens
- ), " ❗ XTTS can only generate text with a maximum of 400 tokens."
+ assert text_tokens.shape[-1] < self.args.gpt_max_text_tokens, (
+ " ❗ XTTS can only generate text with a maximum of 400 tokens."
+ )
fake_inputs = self.gpt.compute_embeddings(
gpt_cond_latent.to(self.device),
@@ -719,13 +718,13 @@ def get_compatible_checkpoint_state_dict(self, model_path):
def load_checkpoint(
self,
config: "XttsConfig",
- checkpoint_dir: Optional[str] = None,
- checkpoint_path: Optional[str] = None,
- vocab_path: Optional[str] = None,
+ checkpoint_dir: str | None = None,
+ checkpoint_path: str | None = None,
+ vocab_path: str | None = None,
eval: bool = True,
strict: bool = True,
use_deepspeed: bool = False,
- speaker_file_path: Optional[str] = None,
+ speaker_file_path: str | None = None,
):
"""
Loads a checkpoint from disk and initializes the model's state and tokenizer.
diff --git a/TTS/tts/utils/data.py b/TTS/tts/utils/data.py
index 22e46b683a..d0269060c8 100644
--- a/TTS/tts/utils/data.py
+++ b/TTS/tts/utils/data.py
@@ -11,7 +11,7 @@ def _pad_data(x, length):
def prepare_data(inputs):
- max_len = max((len(x) for x in inputs))
+ max_len = max(len(x) for x in inputs)
return np.stack([_pad_data(x, max_len) for x in inputs])
@@ -23,7 +23,7 @@ def _pad_tensor(x, length):
def prepare_tensor(inputs, out_steps):
- max_len = max((x.shape[1] for x in inputs))
+ max_len = max(x.shape[1] for x in inputs)
remainder = max_len % out_steps
pad_len = max_len + (out_steps - remainder) if remainder > 0 else max_len
return np.stack([_pad_tensor(x, pad_len) for x in inputs])
@@ -46,7 +46,7 @@ def _pad_stop_target(x: np.ndarray, length: int, pad_val=1) -> np.ndarray:
def prepare_stop_target(inputs, out_steps):
"""Pad row vectors with 1."""
- max_len = max((x.shape[0] for x in inputs))
+ max_len = max(x.shape[0] for x in inputs)
remainder = max_len % out_steps
pad_len = max_len + (out_steps - remainder) if remainder > 0 else max_len
return np.stack([_pad_stop_target(x, pad_len) for x in inputs])
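The prepare_* helpers in data.py only lose a pair of redundant parentheses: a generator expression passed as the sole argument of a call does not need its own parentheses, so `max((len(x) for x in inputs))` and `max(len(x) for x in inputs)` are the same call. For instance:

inputs = [[1, 2, 3], [4, 5]]
assert max((len(x) for x in inputs)) == max(len(x) for x in inputs) == 3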
diff --git a/TTS/tts/utils/helpers.py b/TTS/tts/utils/helpers.py
index ff10f751f2..a3648eff4b 100644
--- a/TTS/tts/utils/helpers.py
+++ b/TTS/tts/utils/helpers.py
@@ -1,5 +1,3 @@
-from typing import Optional
-
import numpy as np
import torch
from scipy.stats import betabinom
@@ -35,7 +33,7 @@ def inverse_transform(self, X):
# from https://gist.github.com/jihunchoi/f1434a77df9db1bb337417854b398df1
-def sequence_mask(sequence_length: torch.Tensor, max_len: Optional[int] = None) -> torch.Tensor:
+def sequence_mask(sequence_length: torch.Tensor, max_len: int | None = None) -> torch.Tensor:
"""Create a sequence mask for filtering padding in a sequence tensor.
Args:
@@ -107,9 +105,9 @@ def rand_segments(
_x_lenghts[len_diff < 0] = segment_size
len_diff = _x_lenghts - segment_size
else:
- assert all(
- len_diff > 0
- ), f" [!] At least one sample is shorter than the segment size ({segment_size}). \n {_x_lenghts}"
+ assert all(len_diff > 0), (
+ f" [!] At least one sample is shorter than the segment size ({segment_size}). \n {_x_lenghts}"
+ )
segment_indices = (torch.rand([B]).type_as(x) * (len_diff + 1)).long()
ret = segment(x, segment_indices, segment_size, pad_short=pad_short)
return ret, segment_indices
@@ -164,7 +162,7 @@ def generate_path(duration: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
def generate_attention(
- duration: torch.Tensor, x_mask: torch.Tensor, y_mask: Optional[torch.Tensor] = None
+ duration: torch.Tensor, x_mask: torch.Tensor, y_mask: torch.Tensor | None = None
) -> torch.Tensor:
"""Generate an attention map from the linear scale durations.
diff --git a/TTS/tts/utils/languages.py b/TTS/tts/utils/languages.py
index c72de2d4e6..5ce7759dd8 100644
--- a/TTS/tts/utils/languages.py
+++ b/TTS/tts/utils/languages.py
@@ -1,5 +1,5 @@
import os
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, Optional
import fsspec
import numpy as np
@@ -27,8 +27,8 @@ class LanguageManager(BaseIDManager):
def __init__(
self,
- language_ids_file_path: Union[str, os.PathLike[Any]] = "",
- config: Optional[Coqpit] = None,
+ language_ids_file_path: str | os.PathLike[Any] = "",
+ config: Coqpit | None = None,
):
super().__init__(id_file_path=language_ids_file_path)
@@ -40,11 +40,11 @@ def num_languages(self) -> int:
return len(list(self.name_to_id.keys()))
@property
- def language_names(self) -> List:
+ def language_names(self) -> list:
return list(self.name_to_id.keys())
@staticmethod
- def parse_language_ids_from_config(c: Coqpit) -> Dict:
+ def parse_language_ids_from_config(c: Coqpit) -> dict:
"""Set language id from config.
Args:
@@ -70,13 +70,13 @@ def set_language_ids_from_config(self, c: Coqpit) -> None:
self.name_to_id = self.parse_language_ids_from_config(c)
@staticmethod
- def parse_ids_from_data(items: List, parse_key: str) -> Any:
+ def parse_ids_from_data(items: list, parse_key: str) -> Any:
raise NotImplementedError
- def set_ids_from_data(self, items: List, parse_key: str) -> Any:
+ def set_ids_from_data(self, items: list, parse_key: str) -> Any:
raise NotImplementedError
- def save_ids_to_file(self, file_path: Union[str, os.PathLike[Any]]) -> None:
+ def save_ids_to_file(self, file_path: str | os.PathLike[Any]) -> None:
"""Save language IDs to a json file.
Args:
diff --git a/TTS/tts/utils/managers.py b/TTS/tts/utils/managers.py
index e009a7c438..49e93454f2 100644
--- a/TTS/tts/utils/managers.py
+++ b/TTS/tts/utils/managers.py
@@ -1,7 +1,7 @@
import json
import os
import random
-from typing import Any, Dict, List, Tuple, Union
+from typing import Any
import fsspec
import numpy as np
@@ -13,7 +13,7 @@
from TTS.utils.generic_utils import is_pytorch_at_least_2_4
-def load_file(path: Union[str, os.PathLike[Any]]):
+def load_file(path: str | os.PathLike[Any]):
path = str(path)
if path.endswith(".json"):
with fsspec.open(path, "r") as f:
@@ -25,7 +25,7 @@ def load_file(path: Union[str, os.PathLike[Any]]):
raise ValueError("Unsupported file type")
-def save_file(obj: Any, path: Union[str, os.PathLike[Any]]):
+def save_file(obj: Any, path: str | os.PathLike[Any]):
path = str(path)
if path.endswith(".json"):
with fsspec.open(path, "w") as f:
@@ -42,23 +42,23 @@ class BaseIDManager:
It defines common `ID` manager specific functions.
"""
- def __init__(self, id_file_path: Union[str, os.PathLike[Any]] = ""):
+ def __init__(self, id_file_path: str | os.PathLike[Any] = ""):
self.name_to_id = {}
if id_file_path:
self.load_ids_from_file(id_file_path)
@staticmethod
- def _load_json(json_file_path: Union[str, os.PathLike[Any]]) -> Dict:
+ def _load_json(json_file_path: str | os.PathLike[Any]) -> dict:
with fsspec.open(str(json_file_path), "r") as f:
return json.load(f)
@staticmethod
- def _save_json(json_file_path: Union[str, os.PathLike[Any]], data: dict) -> None:
+ def _save_json(json_file_path: str | os.PathLike[Any], data: dict) -> None:
with fsspec.open(str(json_file_path), "w") as f:
json.dump(data, f, indent=4)
- def set_ids_from_data(self, items: List, parse_key: str) -> None:
+ def set_ids_from_data(self, items: list, parse_key: str) -> None:
"""Set IDs from data samples.
Args:
@@ -66,7 +66,7 @@ def set_ids_from_data(self, items: List, parse_key: str) -> None:
"""
self.name_to_id = self.parse_ids_from_data(items, parse_key=parse_key)
- def load_ids_from_file(self, file_path: Union[str, os.PathLike[Any]]) -> None:
+ def load_ids_from_file(self, file_path: str | os.PathLike[Any]) -> None:
"""Set IDs from a file.
Args:
@@ -74,7 +74,7 @@ def load_ids_from_file(self, file_path: Union[str, os.PathLike[Any]]) -> None:
"""
self.name_to_id = load_file(file_path)
- def save_ids_to_file(self, file_path: Union[str, os.PathLike[Any]]) -> None:
+ def save_ids_to_file(self, file_path: str | os.PathLike[Any]) -> None:
"""Save IDs to a json file.
Args:
@@ -96,7 +96,7 @@ def get_random_id(self) -> Any:
return None
@staticmethod
- def parse_ids_from_data(items: List, parse_key: str) -> Tuple[Dict]:
+ def parse_ids_from_data(items: list, parse_key: str) -> tuple[dict]:
"""Parse IDs from data samples retured by `load_tts_samples()`.
Args:
@@ -133,10 +133,10 @@ class EmbeddingManager(BaseIDManager):
def __init__(
self,
- embedding_file_path: Union[Union[str, os.PathLike[Any]], list[Union[str, os.PathLike[Any]]]] = "",
- id_file_path: Union[str, os.PathLike[Any]] = "",
- encoder_model_path: Union[str, os.PathLike[Any]] = "",
- encoder_config_path: Union[str, os.PathLike[Any]] = "",
+ embedding_file_path: str | os.PathLike[Any] | list[str | os.PathLike[Any]] = "",
+ id_file_path: str | os.PathLike[Any] = "",
+ encoder_model_path: str | os.PathLike[Any] = "",
+ encoder_config_path: str | os.PathLike[Any] = "",
use_cuda: bool = False,
):
super().__init__(id_file_path=id_file_path)
@@ -179,7 +179,7 @@ def embedding_names(self):
"""Get embedding names."""
return list(self.embeddings_by_names.keys())
- def save_embeddings_to_file(self, file_path: Union[str, os.PathLike[Any]]) -> None:
+ def save_embeddings_to_file(self, file_path: str | os.PathLike[Any]) -> None:
"""Save embeddings to a json file.
Args:
@@ -188,7 +188,7 @@ def save_embeddings_to_file(self, file_path: Union[str, os.PathLike[Any]]) -> No
save_file(self.embeddings, file_path)
@staticmethod
- def read_embeddings_from_file(file_path: Union[str, os.PathLike[Any]]):
+ def read_embeddings_from_file(file_path: str | os.PathLike[Any]):
"""Load embeddings from a json file.
Args:
@@ -207,7 +207,7 @@ def read_embeddings_from_file(file_path: Union[str, os.PathLike[Any]]):
embeddings_by_names[x["name"]].append(x["embedding"])
return name_to_id, clip_ids, embeddings, embeddings_by_names
- def load_embeddings_from_file(self, file_path: Union[str, os.PathLike[Any]]) -> None:
+ def load_embeddings_from_file(self, file_path: str | os.PathLike[Any]) -> None:
"""Load embeddings from a json file.
Args:
@@ -217,7 +217,7 @@ def load_embeddings_from_file(self, file_path: Union[str, os.PathLike[Any]]) ->
file_path
)
- def load_embeddings_from_list_of_files(self, file_paths: list[Union[str, os.PathLike[Any]]]) -> None:
+ def load_embeddings_from_list_of_files(self, file_paths: list[str | os.PathLike[Any]]) -> None:
"""Load embeddings from a list of json files and don't allow duplicate keys.
Args:
@@ -242,7 +242,7 @@ def load_embeddings_from_list_of_files(self, file_paths: list[Union[str, os.Path
# reset name_to_id to get the right speaker ids
self.name_to_id = {name: i for i, name in enumerate(self.name_to_id)}
- def get_embedding_by_clip(self, clip_idx: str) -> List:
+ def get_embedding_by_clip(self, clip_idx: str) -> list:
"""Get embedding by clip ID.
Args:
@@ -253,7 +253,7 @@ def get_embedding_by_clip(self, clip_idx: str) -> List:
"""
return self.embeddings[clip_idx]["embedding"]
- def get_embeddings_by_name(self, idx: str) -> List[List]:
+ def get_embeddings_by_name(self, idx: str) -> list[list]:
"""Get all embeddings of a speaker.
Args:
@@ -264,7 +264,7 @@ def get_embeddings_by_name(self, idx: str) -> List[List]:
"""
return self.embeddings_by_names[idx]
- def get_embeddings_by_names(self) -> Dict:
+ def get_embeddings_by_names(self) -> dict:
"""Get all embeddings by names.
Returns:
@@ -313,11 +313,11 @@ def get_random_embedding(self) -> Any:
return None
- def get_clips(self) -> List:
+ def get_clips(self) -> list:
return sorted(self.embeddings.keys())
def init_encoder(
- self, model_path: Union[str, os.PathLike[Any]], config_path: Union[str, os.PathLike[Any]], use_cuda=False
+ self, model_path: str | os.PathLike[Any], config_path: str | os.PathLike[Any], use_cuda=False
) -> None:
"""Initialize a speaker encoder model.
@@ -335,9 +335,7 @@ def init_encoder(
self.encoder_ap = AudioProcessor(**self.encoder_config.audio)
@torch.inference_mode()
- def compute_embedding_from_clip(
- self, wav_file: Union[Union[str, os.PathLike[Any]], List[Union[str, os.PathLike[Any]]]]
- ) -> list:
+ def compute_embedding_from_clip(self, wav_file: str | os.PathLike[Any] | list[str | os.PathLike[Any]]) -> list:
"""Compute a embedding from a given audio file.
Args:
@@ -374,7 +372,7 @@ def _compute(wav_file: str):
embedding = _compute(wav_file)
return embedding[0].tolist()
- def compute_embeddings(self, feats: Union[torch.Tensor, np.ndarray]) -> List:
+ def compute_embeddings(self, feats: torch.Tensor | np.ndarray) -> list:
"""Compute embedding from features.
Args:
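The managers.py signatures flatten the nested `Union[Union[...], List[Union[...]]]` spellings into flat `str | os.PathLike[Any] | list[...]` unions. Since the same union recurs in many signatures, a module-level alias could express it once; this is only a hedged suggestion, not something the diff introduces:

import os
from typing import Any

PathT = str | os.PathLike[Any]  # hypothetical alias, not part of this diff

def describe(path: PathT | list[PathT]) -> str:
    """Illustrative only: accept a single path or a list of paths."""
    if isinstance(path, list):
        return f"{len(path)} paths"
    return str(path)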
diff --git a/TTS/tts/utils/speakers.py b/TTS/tts/utils/speakers.py
index 89c56583f5..6fab27de5a 100644
--- a/TTS/tts/utils/speakers.py
+++ b/TTS/tts/utils/speakers.py
@@ -1,7 +1,7 @@
import json
import logging
import os
-from typing import Any, Dict, List, Optional, Union
+from typing import Any
import fsspec
import numpy as np
@@ -56,11 +56,11 @@ class SpeakerManager(EmbeddingManager):
def __init__(
self,
- data_items: Optional[list[list[Any]]] = None,
+ data_items: list[list[Any]] | None = None,
d_vectors_file_path: str = "",
- speaker_id_file_path: Union[str, os.PathLike[Any]] = "",
- encoder_model_path: Union[str, os.PathLike[Any]] = "",
- encoder_config_path: Union[str, os.PathLike[Any]] = "",
+ speaker_id_file_path: str | os.PathLike[Any] = "",
+ encoder_model_path: str | os.PathLike[Any] = "",
+ encoder_config_path: str | os.PathLike[Any] = "",
use_cuda: bool = False,
):
super().__init__(
@@ -82,11 +82,11 @@ def num_speakers(self):
def speaker_names(self):
return list(self.name_to_id.keys())
- def get_speakers(self) -> List:
+ def get_speakers(self) -> list:
return self.name_to_id
@staticmethod
- def init_from_config(config: "Coqpit", samples: Union[List[List], List[Dict]] = None) -> "SpeakerManager":
+ def init_from_config(config: "Coqpit", samples: list[list] | list[dict] = None) -> "SpeakerManager":
"""Initialize a speaker manager from config
Args:
@@ -150,7 +150,7 @@ def save_speaker_mapping(out_path, speaker_mapping):
json.dump(speaker_mapping, f, indent=4)
-def get_speaker_manager(c: Coqpit, data: List = None, restore_path: str = None, out_path: str = None) -> SpeakerManager:
+def get_speaker_manager(c: Coqpit, data: list = None, restore_path: str = None, out_path: str = None) -> SpeakerManager:
"""Initiate a `SpeakerManager` instance by the provided config.
Args:
@@ -185,9 +185,9 @@ def get_speaker_manager(c: Coqpit, data: List = None, restore_path: str = None,
elif not c.use_d_vector_file: # restore speaker manager with speaker ID file.
speaker_ids_from_data = speaker_manager.name_to_id
speaker_manager.load_ids_from_file(speakers_file)
- assert all(
- speaker in speaker_manager.name_to_id for speaker in speaker_ids_from_data
- ), " [!] You cannot introduce new speakers to a pre-trained model."
+ assert all(speaker in speaker_manager.name_to_id for speaker in speaker_ids_from_data), (
+ " [!] You cannot introduce new speakers to a pre-trained model."
+ )
elif c.use_d_vector_file and c.d_vector_file:
# new speaker manager with external speaker embeddings.
speaker_manager.load_embeddings_from_file(c.d_vector_file)
diff --git a/TTS/tts/utils/ssim.py b/TTS/tts/utils/ssim.py
index eddf05db3f..660370a832 100644
--- a/TTS/tts/utils/ssim.py
+++ b/TTS/tts/utils/ssim.py
@@ -1,6 +1,5 @@
# Adopted from https://github.com/photosynthesis-team/piq
-from typing import List, Optional, Tuple, Union
import torch
import torch.nn.functional as F
@@ -24,11 +23,11 @@ def _reduce(x: torch.Tensor, reduction: str = "mean") -> torch.Tensor:
def _validate_input(
- tensors: List[torch.Tensor],
- dim_range: Tuple[int, int] = (0, -1),
- data_range: Tuple[float, float] = (0.0, -1.0),
+ tensors: list[torch.Tensor],
+ dim_range: tuple[int, int] = (0, -1),
+ data_range: tuple[float, float] = (0.0, -1.0),
# size_dim_range: Tuple[float, float] = (0., -1.),
- size_range: Optional[Tuple[int, int]] = None,
+ size_range: tuple[int, int] | None = None,
) -> None:
r"""Check that input(-s) satisfies the requirements
Args:
@@ -50,16 +49,16 @@ def _validate_input(
if size_range is None:
assert t.size() == x.size(), f"Expected tensors with same size, got {t.size()} and {x.size()}"
else:
- assert (
- t.size()[size_range[0] : size_range[1]] == x.size()[size_range[0] : size_range[1]]
- ), f"Expected tensors with same size at given dimensions, got {t.size()} and {x.size()}"
+ assert t.size()[size_range[0] : size_range[1]] == x.size()[size_range[0] : size_range[1]], (
+ f"Expected tensors with same size at given dimensions, got {t.size()} and {x.size()}"
+ )
if dim_range[0] == dim_range[1]:
assert t.dim() == dim_range[0], f"Expected number of dimensions to be {dim_range[0]}, got {t.dim()}"
elif dim_range[0] < dim_range[1]:
- assert (
- dim_range[0] <= t.dim() <= dim_range[1]
- ), f"Expected number of dimensions to be between {dim_range[0]} and {dim_range[1]}, got {t.dim()}"
+ assert dim_range[0] <= t.dim() <= dim_range[1], (
+ f"Expected number of dimensions to be between {dim_range[0]} and {dim_range[1]}, got {t.dim()}"
+ )
if data_range[0] < data_range[1]:
assert data_range[0] <= t.min(), f"Expected values to be greater or equal to {data_range[0]}, got {t.min()}"
@@ -89,13 +88,13 @@ def ssim(
y: torch.Tensor,
kernel_size: int = 11,
kernel_sigma: float = 1.5,
- data_range: Union[int, float] = 1.0,
+ data_range: int | float = 1.0,
reduction: str = "mean",
full: bool = False,
downsample: bool = True,
k1: float = 0.01,
k2: float = 0.03,
-) -> List[torch.Tensor]:
+) -> list[torch.Tensor]:
r"""Interface of Structural Similarity (SSIM) index.
Inputs supposed to be in range ``[0, data_range]``.
To match performance with skimage and tensorflow set ``'downsample' = True``.
@@ -218,7 +217,7 @@ def __init__(
k2: float = 0.03,
downsample: bool = True,
reduction: str = "mean",
- data_range: Union[int, float] = 1.0,
+ data_range: int | float = 1.0,
) -> None:
super().__init__()
@@ -270,7 +269,7 @@ def _ssim_per_channel(
kernel: torch.Tensor,
k1: float = 0.01,
k2: float = 0.03,
-) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
r"""Calculate Structural Similarity (SSIM) index for X and Y per channel.
Args:
@@ -286,8 +285,7 @@ def _ssim_per_channel(
"""
if x.size(-1) < kernel.size(-1) or x.size(-2) < kernel.size(-2):
raise ValueError(
- f"Kernel size can't be greater than actual input size. Input size: {x.size()}. "
- f"Kernel size: {kernel.size()}"
+ f"Kernel size can't be greater than actual input size. Input size: {x.size()}. Kernel size: {kernel.size()}"
)
c1 = k1**2
@@ -321,7 +319,7 @@ def _ssim_per_channel_complex(
kernel: torch.Tensor,
k1: float = 0.01,
k2: float = 0.03,
-) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
r"""Calculate Structural Similarity (SSIM) index for Complex X and Y per channel.
Args:
@@ -338,8 +336,7 @@ def _ssim_per_channel_complex(
n_channels = x.size(1)
if x.size(-2) < kernel.size(-1) or x.size(-3) < kernel.size(-2):
raise ValueError(
- f"Kernel size can't be greater than actual input size. Input size: {x.size()}. "
- f"Kernel size: {kernel.size()}"
+ f"Kernel size can't be greater than actual input size. Input size: {x.size()}. Kernel size: {kernel.size()}"
)
c1 = k1**2
diff --git a/TTS/tts/utils/synthesis.py b/TTS/tts/utils/synthesis.py
index 5dc4cc569f..c09c3f5aa2 100644
--- a/TTS/tts/utils/synthesis.py
+++ b/TTS/tts/utils/synthesis.py
@@ -1,13 +1,9 @@
-from typing import Dict, Optional, Union
-
import numpy as np
import torch
from torch import nn
-def numpy_to_torch(
- np_array: np.ndarray, dtype: torch.dtype, device: Union[str, torch.device] = "cpu"
-) -> Optional[torch.Tensor]:
+def numpy_to_torch(np_array: np.ndarray, dtype: torch.dtype, device: str | torch.device = "cpu") -> torch.Tensor | None:
if np_array is None:
return None
return torch.as_tensor(np_array, dtype=dtype, device=device)
@@ -31,7 +27,7 @@ def run_model_torch(
style_text: str = None,
d_vector: torch.Tensor = None,
language_id: torch.Tensor = None,
-) -> Dict:
+) -> dict:
"""Run a torch model for inference. It does not support batch inference.
Args:
@@ -75,14 +71,14 @@ def inv_spectrogram(postnet_output, ap, CONFIG):
return wav
-def id_to_torch(aux_id, device: Union[str, torch.device] = "cpu") -> Optional[torch.Tensor]:
+def id_to_torch(aux_id, device: str | torch.device = "cpu") -> torch.Tensor | None:
if aux_id is not None:
aux_id = np.asarray(aux_id)
aux_id = torch.from_numpy(aux_id).to(device)
return aux_id
-def embedding_to_torch(d_vector, device: Union[str, torch.device] = "cpu") -> Optional[torch.Tensor]:
+def embedding_to_torch(d_vector, device: str | torch.device = "cpu") -> torch.Tensor | None:
if d_vector is not None:
d_vector = np.asarray(d_vector)
d_vector = torch.from_numpy(d_vector).type(torch.FloatTensor)
diff --git a/TTS/tts/utils/text/bangla/phonemizer.py b/TTS/tts/utils/text/bangla/phonemizer.py
index cddcb00fd5..1537240380 100644
--- a/TTS/tts/utils/text/bangla/phonemizer.py
+++ b/TTS/tts/utils/text/bangla/phonemizer.py
@@ -45,7 +45,7 @@ def tag_text(text: str):
# create start and end
text = "start" + text + "end"
# tag text
- parts = re.split("[\u0600-\u06FF]+", text)
+ parts = re.split("[\u0600-\u06ff]+", text)
# remove non chars
parts = [p for p in parts if p.strip()]
# unique parts
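The `\u0600-\u06FF` → `\u0600-\u06ff` change is purely cosmetic: both escapes denote the same code point, so the regex behaviour is unchanged. A quick sanity check (the Arabic sample string is illustrative):

import re

assert "\u06FF" == "\u06ff"
assert re.split("[\u0600-\u06ff]+", "startابجend") == ["start", "end"]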
diff --git a/TTS/tts/utils/text/characters.py b/TTS/tts/utils/text/characters.py
index 4bf9bf6bd5..f8beaef036 100644
--- a/TTS/tts/utils/text/characters.py
+++ b/TTS/tts/utils/text/characters.py
@@ -1,6 +1,5 @@
import logging
from dataclasses import replace
-from typing import Dict
from TTS.tts.configs.shared_configs import CharactersConfig
@@ -47,7 +46,7 @@ class BaseVocabulary:
vocab (Dict): A dictionary of characters and their corresponding indices.
"""
- def __init__(self, vocab: Dict, pad: str = None, blank: str = None, bos: str = None, eos: str = None):
+ def __init__(self, vocab: dict, pad: str = None, blank: str = None, bos: str = None, eos: str = None):
self.vocab = vocab
self.pad = pad
self.blank = blank
@@ -290,9 +289,9 @@ def _create_vocab(self):
self.vocab = _vocab + list(self._punctuations)
if self.is_unique:
duplicates = {x for x in self.vocab if self.vocab.count(x) > 1}
- assert (
- len(self.vocab) == len(self._char_to_id) == len(self._id_to_char)
- ), f" [!] There are duplicate characters in the character set. {duplicates}"
+ assert len(self.vocab) == len(self._char_to_id) == len(self._id_to_char), (
+ f" [!] There are duplicate characters in the character set. {duplicates}"
+ )
def char_to_id(self, char: str) -> int:
try:
diff --git a/TTS/tts/utils/text/chinese_mandarin/numbers.py b/TTS/tts/utils/text/chinese_mandarin/numbers.py
index 4787ea6100..3e6a043918 100644
--- a/TTS/tts/utils/text/chinese_mandarin/numbers.py
+++ b/TTS/tts/utils/text/chinese_mandarin/numbers.py
@@ -1,5 +1,4 @@
#!/usr/bin/env python3
-# -*- coding: utf-8 -*-
# Licensed under WTFPL or the Unlicense or CC0.
# This uses Python 3, but it's easy to port to Python 2 by changing
diff --git a/TTS/tts/utils/text/chinese_mandarin/phonemizer.py b/TTS/tts/utils/text/chinese_mandarin/phonemizer.py
index e9d62e9d06..4dccdd5778 100644
--- a/TTS/tts/utils/text/chinese_mandarin/phonemizer.py
+++ b/TTS/tts/utils/text/chinese_mandarin/phonemizer.py
@@ -1,5 +1,3 @@
-from typing import List
-
try:
import jieba
import pypinyin
@@ -9,7 +7,7 @@
from .pinyinToPhonemes import PINYIN_DICT
-def _chinese_character_to_pinyin(text: str) -> List[str]:
+def _chinese_character_to_pinyin(text: str) -> list[str]:
pinyins = pypinyin.pinyin(text, style=pypinyin.Style.TONE3, heteronym=False, neutral_tone_with_five=True)
pinyins_flat_list = [item for sublist in pinyins for item in sublist]
return pinyins_flat_list
@@ -25,9 +23,9 @@ def _chinese_pinyin_to_phoneme(pinyin: str) -> str:
def chinese_text_to_phonemes(text: str, seperator: str = "|") -> str:
tokenized_text = jieba.cut(text, HMM=False)
tokenized_text = " ".join(tokenized_text)
- pinyined_text: List[str] = _chinese_character_to_pinyin(tokenized_text)
+ pinyined_text: list[str] = _chinese_character_to_pinyin(tokenized_text)
- results: List[str] = []
+ results: list[str] = []
for token in pinyined_text:
if token[-1] in "12345": # TODO transform to is_pinyin()
diff --git a/TTS/tts/utils/text/cleaners.py b/TTS/tts/utils/text/cleaners.py
index f496b9f0dd..795ab246d2 100644
--- a/TTS/tts/utils/text/cleaners.py
+++ b/TTS/tts/utils/text/cleaners.py
@@ -1,7 +1,6 @@
"""Set of default text cleaners"""
import re
-from typing import Optional
from unicodedata import normalize
from anyascii import anyascii
@@ -47,7 +46,7 @@ def remove_aux_symbols(text: str) -> str:
return text
-def replace_symbols(text: str, lang: Optional[str] = "en") -> str:
+def replace_symbols(text: str, lang: str | None = "en") -> str:
"""Replace symbols based on the language tag.
Args:
diff --git a/TTS/tts/utils/text/cmudict.py b/TTS/tts/utils/text/cmudict.py
index f206fb043b..9c0df06196 100644
--- a/TTS/tts/utils/text/cmudict.py
+++ b/TTS/tts/utils/text/cmudict.py
@@ -1,5 +1,3 @@
-# -*- coding: utf-8 -*-
-
import re
VALID_SYMBOLS = [
@@ -121,7 +119,7 @@ def get_arpabet(word, cmudict, punctuation_symbols):
word = word[:-1]
arpabet = cmudict.lookup(word)
if arpabet is not None:
- return first_symbol + "{%s}" % arpabet[0] + last_symbol
+ return first_symbol + "{%s}" % arpabet[0] + last_symbol # noqa: UP031
return first_symbol + word + last_symbol
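The `# noqa: UP031` presumably keeps the printf-style `"{%s}" %` formatting because the literal braces would have to be doubled in an f-string; both spellings produce the same ARPAbet markup. An equivalence check with a made-up lookup result:

arpabet = "HH AH0 L OW1"
assert "{%s}" % arpabet == f"{{{arpabet}}}" == "{HH AH0 L OW1}"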
diff --git a/TTS/tts/utils/text/english/abbreviations.py b/TTS/tts/utils/text/english/abbreviations.py
index cd93c13c8e..20042b255b 100644
--- a/TTS/tts/utils/text/english/abbreviations.py
+++ b/TTS/tts/utils/text/english/abbreviations.py
@@ -2,7 +2,7 @@
# List of (regular expression, replacement) pairs for abbreviations in english:
abbreviations_en = [
- (re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
+ (re.compile(f"\\b{x[0]}\\.", re.IGNORECASE), x[1])
for x in [
("mrs", "misess"),
("mr", "mister"),
diff --git a/TTS/tts/utils/text/english/number_norm.py b/TTS/tts/utils/text/english/number_norm.py
index c912e285e4..be2a4b3084 100644
--- a/TTS/tts/utils/text/english/number_norm.py
+++ b/TTS/tts/utils/text/english/number_norm.py
@@ -1,7 +1,6 @@
-""" from https://github.com/keithito/tacotron """
+"""from https://github.com/keithito/tacotron"""
import re
-from typing import Dict
import inflect
@@ -21,7 +20,7 @@ def _expand_decimal_point(m):
return m.group(1).replace(".", " point ")
-def __expand_currency(value: str, inflection: Dict[float, str]) -> str:
+def __expand_currency(value: str, inflection: dict[float, str]) -> str:
parts = value.replace(",", "").split(".")
if len(parts) > 2:
return f"{value} {inflection[2]}" # Unexpected format
diff --git a/TTS/tts/utils/text/french/abbreviations.py b/TTS/tts/utils/text/french/abbreviations.py
index f580dfed7b..e317bbbf3a 100644
--- a/TTS/tts/utils/text/french/abbreviations.py
+++ b/TTS/tts/utils/text/french/abbreviations.py
@@ -2,7 +2,7 @@
# List of (regular expression, replacement) pairs for abbreviations in french:
abbreviations_fr = [
- (re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
+ (re.compile(f"\\b{x[0]}\\.", re.IGNORECASE), x[1])
for x in [
("M", "monsieur"),
("Mlle", "mademoiselle"),
@@ -38,7 +38,7 @@
("boul", "boulevard"),
]
] + [
- (re.compile("\\b%s" % x[0]), x[1])
+ (re.compile(f"\\b{x[0]}"), x[1])
for x in [
("Mlle", "mademoiselle"),
("Mlles", "mesdemoiselles"),
diff --git a/TTS/tts/utils/text/korean/ko_dictionary.py b/TTS/tts/utils/text/korean/ko_dictionary.py
index 9b739339c6..706f9f5daf 100644
--- a/TTS/tts/utils/text/korean/ko_dictionary.py
+++ b/TTS/tts/utils/text/korean/ko_dictionary.py
@@ -1,4 +1,3 @@
-# coding: utf-8
# Add the word you want to the dictionary.
etc_dictionary = {"1+1": "원플러스원", "2+1": "투플러스원"}
diff --git a/TTS/tts/utils/text/korean/korean.py b/TTS/tts/utils/text/korean/korean.py
index 423aeed377..1b1e0ca0fb 100644
--- a/TTS/tts/utils/text/korean/korean.py
+++ b/TTS/tts/utils/text/korean/korean.py
@@ -1,4 +1,3 @@
-# coding: utf-8
# Code based on https://github.com/carpedm20/multi-speaker-tacotron-tensorflow/blob/master/text/korean.py
import re
diff --git a/TTS/tts/utils/text/phonemizers/bangla_phonemizer.py b/TTS/tts/utils/text/phonemizers/bangla_phonemizer.py
index 3c4a35bbfa..3be7354636 100644
--- a/TTS/tts/utils/text/phonemizers/bangla_phonemizer.py
+++ b/TTS/tts/utils/text/phonemizers/bangla_phonemizer.py
@@ -1,5 +1,3 @@
-from typing import Dict
-
from TTS.tts.utils.text.bangla.phonemizer import bangla_text_to_phonemes
from TTS.tts.utils.text.phonemizers.base import BasePhonemizer
@@ -41,7 +39,7 @@ def _phonemize(self, text, separator):
return self.phonemize_bn(text, separator)
@staticmethod
- def supported_languages() -> Dict:
+ def supported_languages() -> dict:
return {"bn": "Bangla"}
def version(self) -> str:
diff --git a/TTS/tts/utils/text/phonemizers/base.py b/TTS/tts/utils/text/phonemizers/base.py
index 5e701df458..6cc6ec0b37 100644
--- a/TTS/tts/utils/text/phonemizers/base.py
+++ b/TTS/tts/utils/text/phonemizers/base.py
@@ -1,6 +1,5 @@
import abc
import logging
-from typing import List, Tuple
from TTS.tts.utils.text.punctuation import Punctuation
@@ -37,7 +36,7 @@ class BasePhonemizer(abc.ABC):
def __init__(self, language, punctuations=Punctuation.default_puncs(), keep_puncs=False):
# ensure the backend is installed on the system
if not self.is_available():
- raise RuntimeError("{} not installed on your system".format(self.name())) # pragma: nocover
+ raise RuntimeError(f"{self.name()} not installed on your system") # pragma: nocover
# ensure the backend support the requested language
self._language = self._init_language(language)
@@ -53,7 +52,7 @@ def _init_language(self, language):
"""
if not self.is_supported_language(language):
- raise RuntimeError(f'language "{language}" is not supported by the ' f"{self.name()} backend")
+ raise RuntimeError(f'language "{language}" is not supported by the {self.name()} backend')
return language
@property
@@ -93,7 +92,7 @@ def is_supported_language(self, language):
def _phonemize(self, text, separator):
"""The main phonemization method"""
- def _phonemize_preprocess(self, text) -> Tuple[List[str], List]:
+ def _phonemize_preprocess(self, text) -> tuple[list[str], list]:
"""Preprocess the text before phonemization
1. remove spaces
diff --git a/TTS/tts/utils/text/phonemizers/belarusian_phonemizer.py b/TTS/tts/utils/text/phonemizers/belarusian_phonemizer.py
index e5fcab6e09..fa4a515d1a 100644
--- a/TTS/tts/utils/text/phonemizers/belarusian_phonemizer.py
+++ b/TTS/tts/utils/text/phonemizers/belarusian_phonemizer.py
@@ -1,5 +1,3 @@
-from typing import Dict
-
from TTS.tts.utils.text.belarusian.phonemizer import belarusian_text_to_phonemes
from TTS.tts.utils.text.phonemizers.base import BasePhonemizer
@@ -34,7 +32,7 @@ def _phonemize(self, text, separator):
return self.phonemize_be(text, separator)
@staticmethod
- def supported_languages() -> Dict:
+ def supported_languages() -> dict:
return {"be": "Belarusian"}
def version(self) -> str:
diff --git a/TTS/tts/utils/text/phonemizers/espeak_wrapper.py b/TTS/tts/utils/text/phonemizers/espeak_wrapper.py
index a15df716e7..dbcb8994a7 100644
--- a/TTS/tts/utils/text/phonemizers/espeak_wrapper.py
+++ b/TTS/tts/utils/text/phonemizers/espeak_wrapper.py
@@ -5,7 +5,6 @@
import subprocess
import tempfile
from pathlib import Path
-from typing import Optional
from packaging.version import Version
@@ -104,7 +103,7 @@ class ESpeak(BasePhonemizer):
def __init__(
self,
language: str,
- backend: Optional[str] = None,
+ backend: str | None = None,
punctuations: str = Punctuation.default_puncs(),
keep_puncs: bool = True,
):
@@ -184,7 +183,7 @@ def phonemize_espeak(self, text: str, separator: str = "|", *, tie: bool = False
else:
args.append("--ipa=1")
if tie:
- args.append("--tie=%s" % tie)
+ args.append(f"--tie={tie}")
tmp = tempfile.NamedTemporaryFile(mode="w+t", delete=False, encoding="utf8")
tmp.write(text)
diff --git a/TTS/tts/utils/text/phonemizers/gruut_wrapper.py b/TTS/tts/utils/text/phonemizers/gruut_wrapper.py
index f3e9c9abd4..836fccf5b8 100644
--- a/TTS/tts/utils/text/phonemizers/gruut_wrapper.py
+++ b/TTS/tts/utils/text/phonemizers/gruut_wrapper.py
@@ -1,5 +1,4 @@
import importlib
-from typing import List
import gruut
from gruut_ipa import IPA
@@ -114,7 +113,7 @@ def is_supported_language(self, language):
return gruut.is_language_supported(language)
@staticmethod
- def supported_languages() -> List:
+ def supported_languages() -> list:
"""Get a dictionary of supported languages.
Returns:
diff --git a/TTS/tts/utils/text/phonemizers/ja_jp_phonemizer.py b/TTS/tts/utils/text/phonemizers/ja_jp_phonemizer.py
index 878e5e5296..b3b3ba4db7 100644
--- a/TTS/tts/utils/text/phonemizers/ja_jp_phonemizer.py
+++ b/TTS/tts/utils/text/phonemizers/ja_jp_phonemizer.py
@@ -1,5 +1,3 @@
-from typing import Dict
-
from TTS.tts.utils.text.japanese.phonemizer import japanese_text_to_phonemes
from TTS.tts.utils.text.phonemizers.base import BasePhonemizer
@@ -51,7 +49,7 @@ def phonemize(self, text: str, separator="|", language=None) -> str:
return self._phonemize(text, separator)
@staticmethod
- def supported_languages() -> Dict:
+ def supported_languages() -> dict:
return {"ja-jp": "Japanese (Japan)"}
def version(self) -> str:
diff --git a/TTS/tts/utils/text/phonemizers/ko_kr_phonemizer.py b/TTS/tts/utils/text/phonemizers/ko_kr_phonemizer.py
index 0bdba2137b..93930d064e 100644
--- a/TTS/tts/utils/text/phonemizers/ko_kr_phonemizer.py
+++ b/TTS/tts/utils/text/phonemizers/ko_kr_phonemizer.py
@@ -1,5 +1,3 @@
-from typing import Dict
-
from TTS.tts.utils.text.korean.phonemizer import korean_text_to_phonemes
from TTS.tts.utils.text.phonemizers.base import BasePhonemizer
@@ -44,7 +42,7 @@ def phonemize(self, text: str, separator: str = "", character: str = "hangeul",
return self._phonemize(text, separator, character)
@staticmethod
- def supported_languages() -> Dict:
+ def supported_languages() -> dict:
return {"ko-kr": "hangeul(korean)"}
def version(self) -> str:
diff --git a/TTS/tts/utils/text/phonemizers/multi_phonemizer.py b/TTS/tts/utils/text/phonemizers/multi_phonemizer.py
index 1a9e98b091..87fb940f6b 100644
--- a/TTS/tts/utils/text/phonemizers/multi_phonemizer.py
+++ b/TTS/tts/utils/text/phonemizers/multi_phonemizer.py
@@ -1,5 +1,4 @@
import logging
-from typing import Dict, List
from TTS.tts.utils.text.phonemizers import DEF_LANG_TO_PHONEMIZER, get_phonemizer_by_name
@@ -19,7 +18,7 @@ class MultiPhonemizer:
lang_to_phonemizer = {}
- def __init__(self, lang_to_phonemizer_name: Dict = {}) -> None: # pylint: disable=dangerous-default-value
+ def __init__(self, lang_to_phonemizer_name: dict = {}) -> None: # pylint: disable=dangerous-default-value
for k, v in lang_to_phonemizer_name.items():
if v == "" and k in DEF_LANG_TO_PHONEMIZER.keys():
lang_to_phonemizer_name[k] = DEF_LANG_TO_PHONEMIZER[k]
@@ -29,7 +28,7 @@ def __init__(self, lang_to_phonemizer_name: Dict = {}) -> None: # pylint: disab
self.lang_to_phonemizer = self.init_phonemizers(self.lang_to_phonemizer_name)
@staticmethod
- def init_phonemizers(lang_to_phonemizer_name: Dict) -> Dict:
+ def init_phonemizers(lang_to_phonemizer_name: dict) -> dict:
lang_to_phonemizer = {}
for k, v in lang_to_phonemizer_name.items():
lang_to_phonemizer[k] = get_phonemizer_by_name(v, language=k)
@@ -44,7 +43,7 @@ def phonemize(self, text, separator="|", language=""):
raise ValueError("Language must be set for multi-phonemizer to phonemize.")
return self.lang_to_phonemizer[language].phonemize(text, separator)
- def supported_languages(self) -> List:
+ def supported_languages(self) -> list:
return list(self.lang_to_phonemizer.keys())
def print_logs(self, level: int = 0):
diff --git a/TTS/tts/utils/text/phonemizers/zh_cn_phonemizer.py b/TTS/tts/utils/text/phonemizers/zh_cn_phonemizer.py
index 41480c4173..9e70b03a0c 100644
--- a/TTS/tts/utils/text/phonemizers/zh_cn_phonemizer.py
+++ b/TTS/tts/utils/text/phonemizers/zh_cn_phonemizer.py
@@ -1,5 +1,3 @@
-from typing import Dict
-
from TTS.tts.utils.text.chinese_mandarin.phonemizer import chinese_text_to_phonemes
from TTS.tts.utils.text.phonemizers.base import BasePhonemizer
@@ -41,7 +39,7 @@ def _phonemize(self, text, separator):
return self.phonemize_zh_cn(text, separator)
@staticmethod
- def supported_languages() -> Dict:
+ def supported_languages() -> dict:
return {"zh-cn": "Chinese (China)"}
def version(self) -> str:
diff --git a/TTS/tts/utils/text/tokenizer.py b/TTS/tts/utils/text/tokenizer.py
index f653cdf13f..07a8753884 100644
--- a/TTS/tts/utils/text/tokenizer.py
+++ b/TTS/tts/utils/text/tokenizer.py
@@ -1,5 +1,6 @@
import logging
-from typing import Callable, Dict, List, Union
+from collections.abc import Callable
+from typing import Union
from TTS.tts.utils.text import cleaners
from TTS.tts.utils.text.characters import Graphemes, IPAPhonemes
@@ -43,7 +44,7 @@ def __init__(
use_phonemes=False,
text_cleaner: Callable = None,
characters: "BaseCharacters" = None,
- phonemizer: Union["Phonemizer", Dict] = None,
+ phonemizer: Union["Phonemizer", dict] = None,
add_blank: bool = False,
use_eos_bos=False,
):
@@ -65,7 +66,7 @@ def characters(self, new_characters):
self.pad_id = self.characters.char_to_id(self.characters.pad) if self.characters.pad else None
self.blank_id = self.characters.char_to_id(self.characters.blank) if self.characters.blank else None
- def encode(self, text: str) -> List[int]:
+ def encode(self, text: str) -> list[int]:
"""Encodes a string of text as a sequence of IDs."""
token_ids = []
for char in text:
@@ -80,14 +81,14 @@ def encode(self, text: str) -> List[int]:
logger.warning("Character %s not found in the vocabulary. Discarding it.", repr(char))
return token_ids
- def decode(self, token_ids: List[int]) -> str:
+ def decode(self, token_ids: list[int]) -> str:
"""Decodes a sequence of IDs to a string of text."""
text = ""
for token_id in token_ids:
text += self.characters.id_to_char(token_id)
return text
- def text_to_ids(self, text: str, language: str = None) -> List[int]: # pylint: disable=unused-argument
+ def text_to_ids(self, text: str, language: str = None) -> list[int]: # pylint: disable=unused-argument
"""Converts a string of text to a sequence of token IDs.
Args:
@@ -121,15 +122,15 @@ def text_to_ids(self, text: str, language: str = None) -> List[int]: # pylint:
text = self.pad_with_bos_eos(text)
return text
- def ids_to_text(self, id_sequence: List[int]) -> str:
+ def ids_to_text(self, id_sequence: list[int]) -> str:
"""Converts a sequence of token IDs to a string of text."""
return self.decode(id_sequence)
- def pad_with_bos_eos(self, char_sequence: List[str]):
+ def pad_with_bos_eos(self, char_sequence: list[str]):
"""Pads a sequence with the special BOS and EOS characters."""
return [self.characters.bos_id] + list(char_sequence) + [self.characters.eos_id]
- def intersperse_blank_char(self, char_sequence: List[str], use_blank_char: bool = False):
+ def intersperse_blank_char(self, char_sequence: list[str], use_blank_char: bool = False):
"""Intersperses the blank character between characters in a sequence.
Use the ```blank``` character if defined else use the ```pad``` character.
@@ -163,7 +164,7 @@ def init_from_config(config: "Coqpit", characters: "BaseCharacters" = None):
"""
# init cleaners
text_cleaner = None
- if isinstance(config.text_cleaner, (str, list)):
+ if isinstance(config.text_cleaner, str | list):
text_cleaner = getattr(cleaners, config.text_cleaner)
# init characters
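On Python 3.10+, `isinstance()` accepts PEP 604 unions in place of a tuple of types, so `isinstance(config.text_cleaner, str | list)` behaves exactly like the old `(str, list)` form. For example:

assert isinstance("phoneme_cleaners", str | list)
assert isinstance(["phoneme_cleaners", "basic_cleaners"], str | list)
assert not isinstance(None, str | list)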
diff --git a/TTS/utils/audio/numpy_transforms.py b/TTS/utils/audio/numpy_transforms.py
index 0cba7fc8a8..7fd4259178 100644
--- a/TTS/utils/audio/numpy_transforms.py
+++ b/TTS/utils/audio/numpy_transforms.py
@@ -1,7 +1,7 @@
import logging
import os
from io import BytesIO
-from typing import Any, Optional, Union
+from typing import Any
import librosa
import numpy as np
@@ -21,7 +21,7 @@ def build_mel_basis(
fft_size: int,
num_mels: int,
mel_fmin: int,
- mel_fmax: Optional[int] = None,
+ mel_fmax: int | None = None,
**kwargs,
) -> np.ndarray:
"""Build melspectrogram basis.
@@ -177,8 +177,8 @@ def stft(
*,
y: np.ndarray,
fft_size: int,
- hop_length: Optional[int] = None,
- win_length: Optional[int] = None,
+ hop_length: int | None = None,
+ win_length: int | None = None,
pad_mode: str = "reflect",
window: str = "hann",
center: bool = True,
@@ -205,8 +205,8 @@ def stft(
def istft(
*,
y: np.ndarray,
- hop_length: Optional[int] = None,
- win_length: Optional[int] = None,
+ hop_length: int | None = None,
+ win_length: int | None = None,
window: str = "hann",
center: bool = True,
**kwargs,
@@ -248,8 +248,8 @@ def compute_stft_paddings(*, x: np.ndarray, hop_length: int, pad_two_sides: bool
def compute_f0(
*,
x: np.ndarray,
- pitch_fmax: Optional[float] = None,
- pitch_fmin: Optional[float] = None,
+ pitch_fmax: float | None = None,
+ pitch_fmin: float | None = None,
hop_length: int,
win_length: int,
sample_rate: int,
@@ -408,7 +408,7 @@ def rms_volume_norm(*, x: np.ndarray, db_level: float = -27.0, **kwargs) -> np.n
def load_wav(
- *, filename: Union[str, os.PathLike[Any]], sample_rate: Optional[int] = None, resample: bool = False, **kwargs
+ *, filename: str | os.PathLike[Any], sample_rate: int | None = None, resample: bool = False, **kwargs
) -> np.ndarray:
"""Read a wav file using Librosa and optionally resample, silence trim, volume normalize.
@@ -437,7 +437,7 @@ def load_wav(
def save_wav(
*,
wav: np.ndarray,
- path: Union[str, os.PathLike[Any]],
+ path: str | os.PathLike[Any],
sample_rate: int,
pipe_out=None,
do_rms_norm: bool = False,
diff --git a/TTS/utils/audio/processor.py b/TTS/utils/audio/processor.py
index bf07333aea..55b8575aa4 100644
--- a/TTS/utils/audio/processor.py
+++ b/TTS/utils/audio/processor.py
@@ -1,6 +1,6 @@
import logging
import os
-from typing import Any, Optional, Union
+from typing import Any
import librosa
import numpy as np
@@ -222,9 +222,9 @@ def __init__(
self.hop_length = hop_length
self.win_length = win_length
assert min_level_db != 0.0, " [!] min_level_db is 0"
- assert (
- self.win_length <= self.fft_size
- ), f" [!] win_length cannot be larger than fft_size - {self.win_length} vs {self.fft_size}"
+ assert self.win_length <= self.fft_size, (
+ f" [!] win_length cannot be larger than fft_size - {self.win_length} vs {self.fft_size}"
+ )
members = vars(self)
logger.info("Setting up Audio Processor...")
for key, value in members.items():
@@ -283,7 +283,9 @@ def normalize(self, S: np.ndarray) -> np.ndarray:
S_norm = ((2 * self.max_norm) * S_norm) - self.max_norm
if self.clip_norm:
S_norm = np.clip(
- S_norm, -self.max_norm, self.max_norm # pylint: disable=invalid-unary-operand-type
+ S_norm,
+ -self.max_norm, # pylint: disable=invalid-unary-operand-type
+ self.max_norm,
)
return S_norm
S_norm = self.max_norm * S_norm
@@ -318,7 +320,9 @@ def denormalize(self, S: np.ndarray) -> np.ndarray:
if self.symmetric_norm:
if self.clip_norm:
S_denorm = np.clip(
- S_denorm, -self.max_norm, self.max_norm # pylint: disable=invalid-unary-operand-type
+ S_denorm,
+ -self.max_norm, # pylint: disable=invalid-unary-operand-type
+ self.max_norm,
)
S_denorm = ((S_denorm + self.max_norm) * -self.min_level_db / (2 * self.max_norm)) + self.min_level_db
return S_denorm + self.ref_level_db
@@ -351,9 +355,9 @@ def load_stats(self, stats_path: str) -> tuple[np.array, np.array, np.array, np.
if key in skip_parameters:
continue
if key not in ["sample_rate", "trim_db"]:
- assert (
- stats_config[key] == self.__dict__[key]
- ), f" [!] Audio param {key} does not match the value used for computing mean-var stats. {stats_config[key]} vs {self.__dict__[key]}"
+ assert stats_config[key] == self.__dict__[key], (
+ f" [!] Audio param {key} does not match the value used for computing mean-var stats. {stats_config[key]} vs {self.__dict__[key]}"
+ )
return mel_mean, mel_std, linear_mean, linear_std, stats_config
# pylint: disable=attribute-defined-outside-init
@@ -549,7 +553,7 @@ def sound_norm(x: np.ndarray) -> np.ndarray:
return volume_norm(x=x)
### save and load ###
- def load_wav(self, filename: Union[str, os.PathLike[Any]], sr: Optional[int] = None) -> np.ndarray:
+ def load_wav(self, filename: str | os.PathLike[Any], sr: int | None = None) -> np.ndarray:
"""Read a wav file using Librosa and optionally resample, silence trim, volume normalize.
Resampling slows down loading the file significantly. Therefore it is recommended to resample the file before.
@@ -576,9 +580,7 @@ def load_wav(self, filename: Union[str, os.PathLike[Any]], sr: Optional[int] = N
x = rms_volume_norm(x=x, db_level=self.db_level)
return x
- def save_wav(
- self, wav: np.ndarray, path: Union[str, os.PathLike[Any]], sr: Optional[int] = None, pipe_out=None
- ) -> None:
+ def save_wav(self, wav: np.ndarray, path: str | os.PathLike[Any], sr: int | None = None, pipe_out=None) -> None:
"""Save a waveform to a file using Scipy.
Args:
diff --git a/TTS/utils/capacitron_optimizer.py b/TTS/utils/capacitron_optimizer.py
index 7206ffd508..01f303f98d 100644
--- a/TTS/utils/capacitron_optimizer.py
+++ b/TTS/utils/capacitron_optimizer.py
@@ -1,4 +1,4 @@
-from typing import Generator
+from collections.abc import Generator
from trainer.trainer_utils import get_optimizer
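`Generator`, `Iterable`, and `Callable` are imported from `collections.abc` in these hunks because the corresponding `typing` aliases have been deprecated since Python 3.9; the runtime objects are the same. A small sketch with a made-up generator, loosely mirroring how a download is consumed in fixed-size blocks:

from collections.abc import Generator

def chunks(n: int, block: int) -> Generator[int, None, None]:
    """Yield block sizes that cover n bytes."""
    while n > 0:
        yield min(block, n)
        n -= block

assert list(chunks(70, 32)) == [32, 32, 6]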
diff --git a/TTS/utils/download.py b/TTS/utils/download.py
index e94b1d68c8..75ef9164f6 100644
--- a/TTS/utils/download.py
+++ b/TTS/utils/download.py
@@ -7,8 +7,9 @@
import urllib
import urllib.request
import zipfile
+from collections.abc import Iterable
from os.path import expanduser
-from typing import Any, Iterable, List, Optional
+from typing import Any
from torch.utils.model_zoo import tqdm
@@ -16,7 +17,7 @@
def stream_url(
- url: str, start_byte: Optional[int] = None, block_size: int = 32 * 1024, progress_bar: bool = True
+ url: str, start_byte: int | None = None, block_size: int = 32 * 1024, progress_bar: bool = True
) -> Iterable:
"""Stream url by chunk
@@ -36,7 +37,7 @@ def stream_url(
req = urllib.request.Request(url)
if start_byte:
- req.headers["Range"] = "bytes={}-".format(start_byte)
+ req.headers["Range"] = f"bytes={start_byte}-"
with (
urllib.request.urlopen(req) as upointer,
@@ -61,8 +62,8 @@ def stream_url(
def download_url(
url: str,
download_folder: str,
- filename: Optional[str] = None,
- hash_value: Optional[str] = None,
+ filename: str | None = None,
+ hash_value: str | None = None,
hash_type: str = "sha256",
progress_bar: bool = True,
resume: bool = False,
@@ -88,10 +89,10 @@ def download_url(
filepath = os.path.join(download_folder, filename)
if resume and os.path.exists(filepath):
mode = "ab"
- local_size: Optional[int] = os.path.getsize(filepath)
+ local_size: int | None = os.path.getsize(filepath)
elif not resume and os.path.exists(filepath):
- raise RuntimeError("{} already exists. Delete the file manually and retry.".format(filepath))
+ raise RuntimeError(f"{filepath} already exists. Delete the file manually and retry.")
else:
mode = "wb"
local_size = None
@@ -100,7 +101,7 @@ def download_url(
with open(filepath, "rb") as file_obj:
if validate_file(file_obj, hash_value, hash_type):
return
- raise RuntimeError("The hash of {} does not match. Delete the file manually and retry.".format(filepath))
+ raise RuntimeError(f"The hash of {filepath} does not match. Delete the file manually and retry.")
with open(filepath, mode) as fpointer:
for chunk in stream_url(url, start_byte=local_size, progress_bar=progress_bar):
@@ -108,7 +109,7 @@ def download_url(
with open(filepath, "rb") as file_obj:
if hash_value and not validate_file(file_obj, hash_value, hash_type):
- raise RuntimeError("The hash of {} does not match. Delete the file manually and retry.".format(filepath))
+ raise RuntimeError(f"The hash of {filepath} does not match. Delete the file manually and retry.")
def validate_file(file_obj: Any, hash_value: str, hash_type: str = "sha256") -> bool:
@@ -140,7 +141,7 @@ def validate_file(file_obj: Any, hash_value: str, hash_type: str = "sha256") ->
return hash_func.hexdigest() == hash_value
-def extract_archive(from_path: str, to_path: Optional[str] = None, overwrite: bool = False) -> List[str]:
+def extract_archive(from_path: str, to_path: str | None = None, overwrite: bool = False) -> list[str]:
"""Extract archive.
Args:
from_path (str): the path of the archive.
diff --git a/TTS/utils/downloaders.py b/TTS/utils/downloaders.py
index 8705873982..c06c2649ad 100644
--- a/TTS/utils/downloaders.py
+++ b/TTS/utils/downloaders.py
@@ -1,6 +1,5 @@
import logging
import os
-from typing import Optional
from TTS.utils.download import download_kaggle_dataset, download_url, extract_archive
@@ -21,7 +20,7 @@ def download_ljspeech(path: str):
extract_archive(archive)
-def download_vctk(path: str, use_kaggle: Optional[bool] = False):
+def download_vctk(path: str, use_kaggle: bool | None = False):
"""Download and extract VCTK dataset.
Args:
@@ -49,7 +48,7 @@ def download_tweb(path: str):
download_kaggle_dataset("bryanpark/the-world-english-bible-speech-dataset", "TWEB", path)
-def download_libri_tts(path: str, subset: Optional[str] = "all"):
+def download_libri_tts(path: str, subset: str | None = "all"):
"""Download and extract libri tts dataset.
Args:
diff --git a/TTS/utils/generic_utils.py b/TTS/utils/generic_utils.py
index 77566c3f6a..e1df6f6ed4 100644
--- a/TTS/utils/generic_utils.py
+++ b/TTS/utils/generic_utils.py
@@ -1,11 +1,11 @@
-# -*- coding: utf-8 -*-
import datetime
import importlib
import logging
import os
import re
+from collections.abc import Callable
from pathlib import Path
-from typing import Any, Callable, Dict, Optional, TextIO, TypeVar, Union
+from typing import Any, TextIO, TypeVar
import torch
from packaging.version import Version
@@ -16,11 +16,11 @@
_T = TypeVar("_T")
-def exists(val: Union[_T, None]) -> TypeIs[_T]:
+def exists(val: _T | None) -> TypeIs[_T]:
return val is not None
-def default(val: Union[_T, None], d: Union[_T, Callable[[], _T]]) -> _T:
+def default(val: _T | None, d: _T | Callable[[], _T]) -> _T:
if exists(val):
return val
return d() if callable(d) else d
@@ -69,7 +69,7 @@ def get_import_path(obj: object) -> str:
return ".".join([type(obj).__module__, type(obj).__name__])
-def format_aux_input(def_args: Dict, kwargs: Dict) -> Dict:
+def format_aux_input(def_args: dict, kwargs: dict) -> dict:
"""Format kwargs to hande auxilary inputs to models.
Args:
@@ -80,9 +80,9 @@ def format_aux_input(def_args: Dict, kwargs: Dict) -> Dict:
        Dict: arguments with formatted auxiliary inputs.
"""
kwargs = kwargs.copy()
- for name in def_args:
+ for name, arg in def_args.items():
if name not in kwargs or kwargs[name] is None:
- kwargs[name] = def_args[name]
+ kwargs[name] = arg
return kwargs
@@ -108,9 +108,9 @@ def setup_logger(
logger_name: str,
level: int = logging.INFO,
*,
- formatter: Optional[logging.Formatter] = None,
- stream: Optional[TextIO] = None,
- log_dir: Optional[Union[str, os.PathLike[Any]]] = None,
+ formatter: logging.Formatter | None = None,
+ stream: TextIO | None = None,
+ log_dir: str | os.PathLike[Any] | None = None,
log_name: str = "log",
) -> None:
"""Set up a logger.
@@ -146,6 +146,6 @@ def is_pytorch_at_least_2_4() -> bool:
return Version(torch.__version__) >= Version("2.4")
-def optional_to_str(x: Optional[Any]) -> str:
+def optional_to_str(x: Any | None) -> str:
"""Convert input to string, using empty string if input is None."""
return "" if x is None else str(x)
diff --git a/TTS/utils/manage.py b/TTS/utils/manage.py
index 5dff1b84c8..20d6ab226b 100644
--- a/TTS/utils/manage.py
+++ b/TTS/utils/manage.py
@@ -6,7 +6,7 @@
import zipfile
from pathlib import Path
from shutil import copyfile, rmtree
-from typing import Any, Optional, TypedDict, Union
+from typing import Any, TypedDict
import fsspec
import requests
@@ -27,12 +27,12 @@ class ModelItem(TypedDict, total=False):
license: str
author: str
contact: str
- commit: Optional[str]
+ commit: str | None
model_hash: str
tos_required: bool
- default_vocoder: Optional[str]
- model_url: Union[str, list[str]]
- github_rls_url: Union[str, list[str]]
+ default_vocoder: str | None
+ model_url: str | list[str]
+ github_rls_url: str | list[str]
hf_url: list[str]
@@ -49,7 +49,7 @@ class ModelItem(TypedDict, total=False):
}
-class ModelManager(object):
+class ModelManager:
tqdm_progress = None
"""Manage TTS models defined in .models.json.
It provides an interface to list and download
@@ -66,8 +66,8 @@ class ModelManager(object):
def __init__(
self,
- models_file: Optional[Union[str, os.PathLike[Any]]] = None,
- output_prefix: Optional[Union[str, os.PathLike[Any]]] = None,
+ models_file: str | os.PathLike[Any] | None = None,
+ output_prefix: str | os.PathLike[Any] | None = None,
progress_bar: bool = False,
) -> None:
super().__init__()
@@ -84,7 +84,7 @@ def __init__(
path = Path(__file__).parent / "../.models.json"
self.read_models_file(path)
- def read_models_file(self, file_path: Union[str, os.PathLike[Any]]) -> None:
+ def read_models_file(self, file_path: str | os.PathLike[Any]) -> None:
"""Read .models.json as a dict
Args:
@@ -274,7 +274,7 @@ def set_model_url(model_item: ModelItem) -> ModelItem:
model_item["model_url"] = "https://huggingface.co/coqui/"
return model_item
- def _set_model_item(self, model_name: str) -> tuple[ModelItem, str, str, Optional[str]]:
+ def _set_model_item(self, model_name: str) -> tuple[ModelItem, str, str, str | None]:
# fetch model info from the dict
if "fairseq" in model_name:
model_type, lang, dataset, model = model_name.split("/")
@@ -389,7 +389,7 @@ def check_if_configs_are_equal(self, model_name: str, model_item: ModelItem, out
logger.info("%s is already downloaded however it has been changed. Redownloading it...", model_name)
self.create_dir_and_download_model(model_name, model_item, output_path)
- def download_model(self, model_name: str) -> tuple[Path, Optional[Path], ModelItem]:
+ def download_model(self, model_name: str) -> tuple[Path, Path | None, ModelItem]:
"""Download model files given the full model name.
Model name is in the format
'type/language/dataset/model'
@@ -471,7 +471,7 @@ def _find_files(output_path: Path) -> tuple[Path, Path]:
return model_file, config_file
@staticmethod
- def _find_speaker_encoder(output_path: Path) -> Optional[Path]:
+ def _find_speaker_encoder(output_path: Path) -> Path | None:
"""Find the speaker encoder file in the output path
Args:
@@ -523,7 +523,7 @@ def _update_paths(self, output_path: Path, config_path: Path) -> None:
self._update_path("model_args.speaker_encoder_config_path", speaker_encoder_config_path, config_path)
@staticmethod
- def _update_path(field_name: str, new_path: Optional[Path], config_path: Path) -> None:
+ def _update_path(field_name: str, new_path: Path | None, config_path: Path) -> None:
"""Update the path in the model config.json for the current environment after download"""
if new_path is not None and new_path.is_file():
config = load_config(str(config_path))
@@ -619,9 +619,7 @@ def _download_tar_file(file_url: str, output_folder: Path, progress_bar: bool) -
rmtree(output_folder / tar_names[0])
@staticmethod
- def _download_model_files(
- file_urls: list[str], output_folder: Union[str, os.PathLike[Any]], progress_bar: bool
- ) -> None:
+ def _download_model_files(file_urls: list[str], output_folder: str | os.PathLike[Any], progress_bar: bool) -> None:
"""Download the github releases"""
output_folder = Path(output_folder)
for file_url in file_urls:
diff --git a/TTS/utils/radam.py b/TTS/utils/radam.py
index cbd14990f3..b893d115c9 100644
--- a/TTS/utils/radam.py
+++ b/TTS/utils/radam.py
@@ -9,16 +9,16 @@
class RAdam(Optimizer):
def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, degenerated_to_sgd=True):
if lr < 0.0:
- raise ValueError("Invalid learning rate: {}".format(lr))
+ raise ValueError(f"Invalid learning rate: {lr}")
if eps < 0.0:
- raise ValueError("Invalid epsilon value: {}".format(eps))
+ raise ValueError(f"Invalid epsilon value: {eps}")
if not 0.0 <= betas[0] < 1.0:
- raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
+ raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}")
if not 0.0 <= betas[1] < 1.0:
- raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
+ raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}")
self.degenerated_to_sgd = degenerated_to_sgd
- if isinstance(params, (list, tuple)) and len(params) > 0 and isinstance(params[0], dict):
+ if isinstance(params, list | tuple) and len(params) > 0 and isinstance(params[0], dict):
for param in params:
if "betas" in param and (param["betas"][0] != betas[0] or param["betas"][1] != betas[1]):
param["buffer"] = [[None, None, None] for _ in range(10)]
diff --git a/TTS/utils/samplers.py b/TTS/utils/samplers.py
index b08a763a33..d24733977a 100644
--- a/TTS/utils/samplers.py
+++ b/TTS/utils/samplers.py
@@ -1,6 +1,6 @@
import math
import random
-from typing import Callable, List, Union
+from collections.abc import Callable
from torch.utils.data.sampler import BatchSampler, Sampler, SubsetRandomSampler
@@ -49,9 +49,9 @@ def __init__(
label_key="class_name",
):
super().__init__(dataset_items)
- assert (
- batch_size % (num_classes_in_batch * num_gpus) == 0
- ), "Batch size must be divisible by number of classes times the number of data parallel devices (if enabled)."
+ assert batch_size % (num_classes_in_batch * num_gpus) == 0, (
+ "Batch size must be divisible by number of classes times the number of data parallel devices (if enabled)."
+ )
label_indices = {}
for idx, item in enumerate(dataset_items):
@@ -176,7 +176,7 @@ def __init__(
data,
batch_size,
drop_last,
- sort_key: Union[Callable, List] = identity,
+ sort_key: Callable | list = identity,
bucket_size_multiplier=100,
):
super().__init__(sampler, batch_size, drop_last)
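The assert reformatting throughout this diff only moves the parentheses so they wrap the message rather than the condition; the checked expression and the error text are unchanged. The samplers.py case in isolation, with illustrative values:

batch_size, num_classes_in_batch, num_gpus = 8, 4, 1
assert batch_size % (num_classes_in_batch * num_gpus) == 0, (
    "Batch size must be divisible by number of classes times the number of data parallel devices (if enabled)."
)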
diff --git a/TTS/utils/synthesizer.py b/TTS/utils/synthesizer.py
index fafeddfd75..cebb094a48 100644
--- a/TTS/utils/synthesizer.py
+++ b/TTS/utils/synthesizer.py
@@ -2,7 +2,7 @@
import os
import time
from pathlib import Path
-from typing import Any, List, Optional, Union
+from typing import Any
import numpy as np
import pysbd
@@ -30,18 +30,18 @@ class Synthesizer(nn.Module):
def __init__(
self,
*,
- tts_checkpoint: Optional[Union[str, os.PathLike[Any]]] = None,
- tts_config_path: Optional[Union[str, os.PathLike[Any]]] = None,
- tts_speakers_file: Optional[Union[str, os.PathLike[Any]]] = None,
- tts_languages_file: Optional[Union[str, os.PathLike[Any]]] = None,
- vocoder_checkpoint: Optional[Union[str, os.PathLike[Any]]] = None,
- vocoder_config: Optional[Union[str, os.PathLike[Any]]] = None,
- encoder_checkpoint: Optional[Union[str, os.PathLike[Any]]] = None,
- encoder_config: Optional[Union[str, os.PathLike[Any]]] = None,
- vc_checkpoint: Optional[Union[str, os.PathLike[Any]]] = None,
- vc_config: Optional[Union[str, os.PathLike[Any]]] = None,
- model_dir: Optional[Union[str, os.PathLike[Any]]] = None,
- voice_dir: Optional[Union[str, os.PathLike[Any]]] = None,
+ tts_checkpoint: str | os.PathLike[Any] | None = None,
+ tts_config_path: str | os.PathLike[Any] | None = None,
+ tts_speakers_file: str | os.PathLike[Any] | None = None,
+ tts_languages_file: str | os.PathLike[Any] | None = None,
+ vocoder_checkpoint: str | os.PathLike[Any] | None = None,
+ vocoder_config: str | os.PathLike[Any] | None = None,
+ encoder_checkpoint: str | os.PathLike[Any] | None = None,
+ encoder_config: str | os.PathLike[Any] | None = None,
+ vc_checkpoint: str | os.PathLike[Any] | None = None,
+ vc_config: str | os.PathLike[Any] | None = None,
+ model_dir: str | os.PathLike[Any] | None = None,
+ voice_dir: str | os.PathLike[Any] | None = None,
use_cuda: bool = False,
) -> None:
"""General 🐸 TTS interface for inference. It takes a tts and a vocoder
@@ -248,7 +248,7 @@ def _load_vocoder(self, model_file: str, model_config: str, use_cuda: bool) -> N
if use_cuda:
self.vocoder_model.cuda()
- def split_into_sentences(self, text) -> List[str]:
+ def split_into_sentences(self, text) -> list[str]:
"""Split give text into sentences.
Args:
@@ -259,7 +259,7 @@ def split_into_sentences(self, text) -> List[str]:
"""
return self.seg.segment(text)
- def save_wav(self, wav: List[int], path: str, pipe_out=None) -> None:
+ def save_wav(self, wav: list[int], path: str, pipe_out=None) -> None:
"""Save the waveform as a file.
Args:
@@ -274,7 +274,7 @@ def save_wav(self, wav: List[int], path: str, pipe_out=None) -> None:
wav = np.array(wav)
save_wav(wav=wav, path=path, sample_rate=self.output_sample_rate, pipe_out=pipe_out)
- def voice_conversion(self, source_wav: str, target_wav: Union[str, list[str]], **kwargs) -> List[int]:
+ def voice_conversion(self, source_wav: str, target_wav: str | list[str], **kwargs) -> list[int]:
start_time = time.time()
if not isinstance(target_wav, list):
@@ -302,7 +302,7 @@ def tts(
reference_speaker_name=None,
split_sentences: bool = True,
**kwargs,
- ) -> List[int]:
+ ) -> list[int]:
"""🐸 TTS magic. Run all the models and generate speech.
Args:
diff --git a/TTS/vc/configs/freevc_config.py b/TTS/vc/configs/freevc_config.py
index d600bfb1f4..37f8048b7f 100644
--- a/TTS/vc/configs/freevc_config.py
+++ b/TTS/vc/configs/freevc_config.py
@@ -1,5 +1,4 @@
from dataclasses import dataclass, field
-from typing import List, Optional
from coqpit import Coqpit
@@ -47,7 +46,7 @@ class FreeVCAudioConfig(Coqpit):
win_length: int = field(default=1280)
n_mel_channels: int = field(default=80)
mel_fmin: float = field(default=0.0)
- mel_fmax: Optional[float] = field(default=None)
+ mel_fmax: float | None = field(default=None)
@dataclass
@@ -122,11 +121,11 @@ class FreeVCArgs(Coqpit):
kernel_size: int = field(default=3)
p_dropout: float = field(default=0.1)
resblock: str = field(default="1")
- resblock_kernel_sizes: List[int] = field(default_factory=lambda: [3, 7, 11])
- resblock_dilation_sizes: List[List[int]] = field(default_factory=lambda: [[1, 3, 5], [1, 3, 5], [1, 3, 5]])
- upsample_rates: List[int] = field(default_factory=lambda: [10, 8, 2, 2])
+ resblock_kernel_sizes: list[int] = field(default_factory=lambda: [3, 7, 11])
+ resblock_dilation_sizes: list[list[int]] = field(default_factory=lambda: [[1, 3, 5], [1, 3, 5], [1, 3, 5]])
+ upsample_rates: list[int] = field(default_factory=lambda: [10, 8, 2, 2])
upsample_initial_channel: int = field(default=512)
- upsample_kernel_sizes: List[int] = field(default_factory=lambda: [16, 16, 4, 4])
+ upsample_kernel_sizes: list[int] = field(default_factory=lambda: [16, 16, 4, 4])
n_layers_q: int = field(default=3)
use_spectral_norm: bool = field(default=False)
gin_channels: int = field(default=256)
@@ -269,7 +268,7 @@ class FreeVCConfig(BaseVCConfig):
# use d-vectors
use_d_vector_file: bool = False
- d_vector_file: List[str] = None
+ d_vector_file: list[str] = None
d_vector_dim: int = None
def __post_init__(self):
diff --git a/TTS/vc/configs/openvoice_config.py b/TTS/vc/configs/openvoice_config.py
index 261cdd6f47..167a61ddb3 100644
--- a/TTS/vc/configs/openvoice_config.py
+++ b/TTS/vc/configs/openvoice_config.py
@@ -1,5 +1,4 @@
from dataclasses import dataclass, field
-from typing import Optional
from coqpit import Coqpit
@@ -187,13 +186,13 @@ class OpenVoiceConfig(BaseVCConfig):
# multi-speaker settings
# use speaker embedding layer
num_speakers: int = 0
- speakers_file: Optional[str] = None
+ speakers_file: str | None = None
speaker_embedding_channels: int = 256
# use d-vectors
use_d_vector_file: bool = False
- d_vector_file: Optional[list[str]] = None
- d_vector_dim: Optional[int] = None
+ d_vector_file: list[str] | None = None
+ d_vector_dim: int | None = None
def __post_init__(self) -> None:
for key, val in self.model_args.items():
diff --git a/TTS/vc/configs/shared_configs.py b/TTS/vc/configs/shared_configs.py
index 3c6b1a32cf..b84a97e487 100644
--- a/TTS/vc/configs/shared_configs.py
+++ b/TTS/vc/configs/shared_configs.py
@@ -1,5 +1,4 @@
from dataclasses import dataclass, field
-from typing import List
from TTS.config import BaseAudioConfig, BaseDatasetConfig, BaseTrainingConfig
@@ -132,7 +131,7 @@ class BaseVCConfig(BaseTrainingConfig):
shuffle: bool = False
drop_last: bool = False
# dataset
- datasets: List[BaseDatasetConfig] = field(default_factory=lambda: [BaseDatasetConfig()])
+ datasets: list[BaseDatasetConfig] = field(default_factory=lambda: [BaseDatasetConfig()])
# optimizer
optimizer: str = "radam"
optimizer_params: dict = None
@@ -140,7 +139,7 @@ class BaseVCConfig(BaseTrainingConfig):
lr_scheduler: str = None
lr_scheduler_params: dict = field(default_factory=lambda: {})
# testing
- test_sentences: List[str] = field(default_factory=lambda: [])
+ test_sentences: list[str] = field(default_factory=lambda: [])
# evaluation
eval_split_max_size: int = None
eval_split_size: float = 0.01
diff --git a/TTS/vc/layers/freevc/modules.py b/TTS/vc/layers/freevc/modules.py
index c34f22d701..92df39b5e0 100644
--- a/TTS/vc/layers/freevc/modules.py
+++ b/TTS/vc/layers/freevc/modules.py
@@ -48,7 +48,7 @@ def forward(self, x, x_mask):
class WN(torch.nn.Module):
def __init__(self, hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=0, p_dropout=0):
- super(WN, self).__init__()
+ super().__init__()
assert kernel_size % 2 == 1
self.hidden_channels = hidden_channels
self.kernel_size = (kernel_size,)
@@ -122,7 +122,7 @@ def remove_weight_norm(self):
class ResBlock1(torch.nn.Module):
def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
- super(ResBlock1, self).__init__()
+ super().__init__()
self.convs1 = nn.ModuleList(
[
weight_norm(
@@ -198,7 +198,7 @@ def remove_weight_norm(self):
class ResBlock2(torch.nn.Module):
def __init__(self, channels, kernel_size=3, dilation=(1, 3)):
- super(ResBlock2, self).__init__()
+ super().__init__()
self.convs = nn.ModuleList(
[
weight_norm(
diff --git a/TTS/vc/layers/freevc/speaker_encoder/audio.py b/TTS/vc/layers/freevc/speaker_encoder/audio.py
index 5fa317ce45..5d14bf2f19 100644
--- a/TTS/vc/layers/freevc/speaker_encoder/audio.py
+++ b/TTS/vc/layers/freevc/speaker_encoder/audio.py
@@ -1,5 +1,4 @@
from pathlib import Path
-from typing import Optional, Union
# import webrtcvad
import librosa
@@ -16,7 +15,7 @@
int16_max = (2**15) - 1
-def preprocess_wav(fpath_or_wav: Union[str, Path, np.ndarray], source_sr: Optional[int] = None):
+def preprocess_wav(fpath_or_wav: str | Path | np.ndarray, source_sr: int | None = None):
"""
Applies the preprocessing operations used in training the Speaker Encoder to a waveform
either on disk or in memory. The waveform will be resampled to match the data hyperparameters.
diff --git a/TTS/vc/layers/freevc/speaker_encoder/speaker_encoder.py b/TTS/vc/layers/freevc/speaker_encoder/speaker_encoder.py
index 62fae59bc1..d2f4ffe394 100644
--- a/TTS/vc/layers/freevc/speaker_encoder/speaker_encoder.py
+++ b/TTS/vc/layers/freevc/speaker_encoder/speaker_encoder.py
@@ -1,6 +1,5 @@
import logging
from time import perf_counter as timer
-from typing import List
import numpy as np
import torch
@@ -89,7 +88,7 @@ def compute_partial_slices(n_samples: int, rate, min_coverage):
assert 0 < min_coverage <= 1
# Compute how many frames separate two partial utterances
- samples_per_frame = int((sampling_rate * mel_window_step / 1000))
+ samples_per_frame = int(sampling_rate * mel_window_step / 1000)
n_frames = int(np.ceil((n_samples + 1) / samples_per_frame))
frame_step = int(np.round((sampling_rate / rate) / samples_per_frame))
assert 0 < frame_step, "The rate is too high"
@@ -162,7 +161,7 @@ def embed_utterance(self, wav: np.ndarray, return_partials=False, rate=1.3, min_
return embed, partial_embeds, wav_slices
return embed
- def embed_speaker(self, wavs: List[np.ndarray], **kwargs):
+ def embed_speaker(self, wavs: list[np.ndarray], **kwargs):
"""
Compute the embedding of a collection of wavs (presumably from the same speaker) by
averaging their embedding and L2-normalizing it.
diff --git a/TTS/vc/layers/freevc/wavlm/modules.py b/TTS/vc/layers/freevc/wavlm/modules.py
index 37c1a6e877..cf31a866de 100644
--- a/TTS/vc/layers/freevc/wavlm/modules.py
+++ b/TTS/vc/layers/freevc/wavlm/modules.py
@@ -9,7 +9,6 @@
import math
import warnings
-from typing import Dict, Optional, Tuple
import torch
import torch.nn.functional as F
@@ -89,7 +88,7 @@ class Swish(nn.Module):
def __init__(self):
"""Construct an MultiHeadedAttention object."""
- super(Swish, self).__init__()
+ super().__init__()
self.act = torch.nn.Sigmoid()
def forward(self, x):
@@ -98,7 +97,7 @@ def forward(self, x):
class GLU_Linear(nn.Module):
def __init__(self, input_dim, output_dim, glu_type="sigmoid", bias_in_glu=True):
- super(GLU_Linear, self).__init__()
+ super().__init__()
self.glu_type = glu_type
self.output_dim = output_dim
@@ -158,7 +157,7 @@ def get_activation_fn(activation: str):
elif activation == "glu":
return lambda x: x
else:
- raise RuntimeError("--activation-fn {} not supported".format(activation))
+ raise RuntimeError(f"--activation-fn {activation} not supported")
def init_bert_params(module):
@@ -219,7 +218,7 @@ def quant_noise(module, p, block_size):
return module
# supported modules
- assert isinstance(module, (nn.Linear, nn.Embedding, nn.Conv2d))
+ assert isinstance(module, nn.Linear | nn.Embedding | nn.Conv2d)
# test whether module.weight has the right sizes wrt block_size
is_conv = module.weight.ndim == 4
@@ -331,7 +330,7 @@ def __init__(
self.encoder_decoder_attention = encoder_decoder_attention
assert not self.self_attention or self.qkv_same_dim, (
- "Self-attention requires query, key and " "value to be of the same size"
+ "Self-attention requires query, key and value to be of the same size"
)
k_bias = True
@@ -424,17 +423,17 @@ def compute_bias(self, query_length, key_length):
def forward(
self,
query,
- key: Optional[Tensor],
- value: Optional[Tensor],
- key_padding_mask: Optional[Tensor] = None,
- incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
+ key: Tensor | None,
+ value: Tensor | None,
+ key_padding_mask: Tensor | None = None,
+ incremental_state: dict[str, dict[str, Tensor | None]] | None = None,
need_weights: bool = True,
static_kv: bool = False,
- attn_mask: Optional[Tensor] = None,
+ attn_mask: Tensor | None = None,
before_softmax: bool = False,
need_head_weights: bool = False,
- position_bias: Optional[Tensor] = None,
- ) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]:
+ position_bias: Tensor | None = None,
+ ) -> tuple[Tensor, Tensor | None, Tensor | None]:
"""Input shape: Time x Batch x Channel
Args:
@@ -605,7 +604,7 @@ def forward(
else:
assert v is not None
v = torch.cat([prev_value, v], dim=1)
- prev_key_padding_mask: Optional[Tensor] = None
+ prev_key_padding_mask: Tensor | None = None
if "prev_key_padding_mask" in saved_state:
prev_key_padding_mask = saved_state["prev_key_padding_mask"]
assert k is not None and v is not None
@@ -700,7 +699,7 @@ def forward(
assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]
attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
attn = self.out_proj(attn)
- attn_weights: Optional[Tensor] = None
+ attn_weights: Tensor | None = None
if need_weights:
attn_weights = attn_weights_float.view(bsz, self.num_heads, tgt_len, src_len).transpose(1, 0)
if not need_head_weights:
@@ -711,12 +710,12 @@ def forward(
@staticmethod
def _append_prev_key_padding_mask(
- key_padding_mask: Optional[Tensor],
- prev_key_padding_mask: Optional[Tensor],
+ key_padding_mask: Tensor | None,
+ prev_key_padding_mask: Tensor | None,
batch_size: int,
src_len: int,
static_kv: bool,
- ) -> Optional[Tensor]:
+ ) -> Tensor | None:
# saved key padding masks have shape (bsz, seq_len)
if prev_key_padding_mask is not None and static_kv:
new_key_padding_mask = prev_key_padding_mask
@@ -748,19 +747,19 @@ def _append_prev_key_padding_mask(
return new_key_padding_mask
def _get_input_buffer(
- self, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]]
- ) -> Dict[str, Optional[Tensor]]:
+ self, incremental_state: dict[str, dict[str, Tensor | None]] | None
+ ) -> dict[str, Tensor | None]:
result = self.get_incremental_state(incremental_state, "attn_state")
if result is not None:
return result
else:
- empty_result: Dict[str, Optional[Tensor]] = {}
+ empty_result: dict[str, Tensor | None] = {}
return empty_result
def _set_input_buffer(
self,
- incremental_state: Dict[str, Dict[str, Optional[Tensor]]],
- buffer: Dict[str, Optional[Tensor]],
+ incremental_state: dict[str, dict[str, Tensor | None]],
+ buffer: dict[str, Tensor | None],
):
return self.set_incremental_state(incremental_state, "attn_state", buffer)
diff --git a/TTS/vc/layers/freevc/wavlm/wavlm.py b/TTS/vc/layers/freevc/wavlm/wavlm.py
index 0247ec53c1..6358662e18 100644
--- a/TTS/vc/layers/freevc/wavlm/wavlm.py
+++ b/TTS/vc/layers/freevc/wavlm/wavlm.py
@@ -9,7 +9,7 @@
import logging
import math
-from typing import Any, List, Optional, Tuple
+from typing import Any
import numpy as np
import torch
@@ -33,8 +33,8 @@
def compute_mask_indices(
- shape: Tuple[int, int],
- padding_mask: Optional[torch.Tensor],
+ shape: tuple[int, int],
+ padding_mask: torch.Tensor | None,
mask_prob: float,
mask_length: int,
mask_type: str = "static",
@@ -68,8 +68,7 @@ def compute_mask_indices(
all_num_mask = int(
# add a random number for probabilistic rounding
- mask_prob * all_sz / float(mask_length)
- + np.random.rand()
+ mask_prob * all_sz / float(mask_length) + np.random.rand()
)
all_num_mask = max(min_masks, all_num_mask)
@@ -80,8 +79,7 @@ def compute_mask_indices(
sz = all_sz - padding_mask[i].long().sum().item()
num_mask = int(
# add a random number for probabilistic rounding
- mask_prob * sz / float(mask_length)
- + np.random.rand()
+ mask_prob * sz / float(mask_length) + np.random.rand()
)
num_mask = max(min_masks, num_mask)
else:
@@ -155,9 +153,7 @@ def arrange(s, e, length, keep_length):
class WavLMConfig:
def __init__(self, cfg=None):
- self.extractor_mode: str = (
- "default" # mode for feature extractor. default has a single group norm with d groups in the first conv block, whereas layer_norm has layer norms in every block (meant to use with normalize=True)
- )
+ self.extractor_mode: str = "default" # mode for feature extractor. default has a single group norm with d groups in the first conv block, whereas layer_norm has layer norms in every block (meant to use with normalize=True)
self.encoder_layers: int = 12 # num encoder layers in the transformer
self.encoder_embed_dim: int = 768 # encoder embedding dimension
@@ -166,9 +162,7 @@ def __init__(self, cfg=None):
self.activation_fn: str = "gelu" # activation function to use
self.layer_norm_first: bool = False # apply layernorm first in the transformer
- self.conv_feature_layers: str = (
- "[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2" # string describing convolutional feature extraction layers in form of a python list that contains [(dim, kernel_size, stride), ...]
- )
+ self.conv_feature_layers: str = "[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2" # string describing convolutional feature extraction layers in form of a python list that contains [(dim, kernel_size, stride), ...]
self.conv_bias: bool = False # include bias in conv encoder
self.feature_grad_mult: float = 1.0 # multiply feature extractor var grads by this
@@ -225,7 +219,7 @@ def __init__(
cfg: WavLMConfig,
) -> None:
super().__init__()
- logger.info(f"WavLM Config: {cfg.__dict__}")
+ logger.info("WavLM Config: %s", cfg.__dict__)
self.cfg = cfg
feature_enc_layers = eval(cfg.conv_feature_layers)
@@ -317,10 +311,10 @@ def forward_padding_mask(
def extract_features(
self,
source: torch.Tensor,
- padding_mask: Optional[torch.Tensor] = None,
+ padding_mask: torch.Tensor | None = None,
mask: bool = False,
ret_conv: bool = False,
- output_layer: Optional[int] = None,
+ output_layer: int | None = None,
ret_layer_results: bool = False,
) -> tuple[torch.Tensor, dict[str, Any]]:
if self.feature_grad_mult > 0:
@@ -367,7 +361,7 @@ def extract_features(
class ConvFeatureExtractionModel(nn.Module):
def __init__(
self,
- conv_layers: List[Tuple[int, int, int]],
+ conv_layers: list[tuple[int, int, int]],
dropout: float = 0.0,
mode: str = "default",
conv_bias: bool = False,
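The `logger.info(f"...")` → `logger.info("WavLM Config: %s", cfg.__dict__)` change defers string interpolation to the logging framework, so the formatting work is skipped when INFO records are filtered out. The same pattern in isolation, with a stand-in config dict:

import logging

logger = logging.getLogger(__name__)
logger.setLevel(logging.WARNING)
# The dict below is never rendered to a string because INFO is disabled here.
logger.info("WavLM Config: %s", {"encoder_layers": 12, "encoder_embed_dim": 768})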
diff --git a/TTS/vc/models/__init__.py b/TTS/vc/models/__init__.py
index 8151a0445e..859eaeb2a7 100644
--- a/TTS/vc/models/__init__.py
+++ b/TTS/vc/models/__init__.py
@@ -1,7 +1,6 @@
import importlib
import logging
import re
-from typing import Dict, List, Optional, Union
from TTS.vc.configs.shared_configs import BaseVCConfig
from TTS.vc.models.base_vc import BaseVC
diff --git a/TTS/vc/models/base_vc.py b/TTS/vc/models/base_vc.py
index 6f7fb192b0..a953b901e8 100644
--- a/TTS/vc/models/base_vc.py
+++ b/TTS/vc/models/base_vc.py
@@ -1,7 +1,7 @@
import logging
import os
import random
-from typing import Any, Optional, Union
+from typing import Any
import torch
import torch.distributed as dist
@@ -37,9 +37,9 @@ class BaseVC(BaseTrainerModel):
def __init__(
self,
config: Coqpit,
- ap: Optional[AudioProcessor] = None,
- speaker_manager: Optional[SpeakerManager] = None,
- language_manager: Optional[LanguageManager] = None,
+ ap: AudioProcessor | None = None,
+ speaker_manager: SpeakerManager | None = None,
+ language_manager: LanguageManager | None = None,
) -> None:
super().__init__()
self.config = config
@@ -69,7 +69,7 @@ def _set_model_args(self, config: Coqpit) -> None:
else:
raise ValueError("config must be either a *Config or *Args")
- def init_multispeaker(self, config: Coqpit, data: Optional[list[Any]] = None) -> None:
+ def init_multispeaker(self, config: Coqpit, data: list[Any] | None = None) -> None:
"""Initialize a speaker embedding layer if needen and define expected embedding channel size for defining
`in_channels` size of the connected layers.
@@ -106,7 +106,7 @@ def get_aux_input(self, **kwargs: Any) -> dict[str, Any]:
"""Prepare and return `aux_input` used by `forward()`"""
return {"speaker_id": None, "style_wav": None, "d_vector": None, "language_id": None}
- def get_aux_input_from_test_sentences(self, sentence_info: Union[str, list[str]]) -> dict[str, Any]:
+ def get_aux_input_from_test_sentences(self, sentence_info: str | list[str]) -> dict[str, Any]:
if hasattr(self.config, "model_args"):
config = self.config.model_args
else:
@@ -199,9 +199,9 @@ def format_batch(self, batch: dict[str, Any]) -> dict[str, Any]:
extra_frames = dur.sum() - mel_lengths[idx]
largest_idxs = torch.argsort(-dur)[:extra_frames]
dur[largest_idxs] -= 1
- assert (
- dur.sum() == mel_lengths[idx]
- ), f" [!] total duration {dur.sum()} vs spectrogram length {mel_lengths[idx]}"
+ assert dur.sum() == mel_lengths[idx], (
+ f" [!] total duration {dur.sum()} vs spectrogram length {mel_lengths[idx]}"
+ )
durations[idx, : text_lengths[idx]] = dur
# set stop targets wrt reduction factor
@@ -275,10 +275,10 @@ def get_data_loader(
config: Coqpit,
assets: dict,
is_eval: bool,
- samples: Union[list[dict], list[list]],
+ samples: list[dict] | list[list],
verbose: bool,
num_gpus: int,
- rank: Optional[int] = None,
+ rank: int | None = None,
) -> "DataLoader":
if is_eval and not config.run_eval:
loader = None
@@ -402,13 +402,11 @@ def test_run(self, assets: dict) -> tuple[dict, dict]:
use_griffin_lim=True,
do_trim_silence=False,
)
- test_audios["{}-audio".format(idx)] = outputs_dict["wav"]
- test_figures["{}-prediction".format(idx)] = plot_spectrogram(
+ test_audios[f"{idx}-audio"] = outputs_dict["wav"]
+ test_figures[f"{idx}-prediction"] = plot_spectrogram(
outputs_dict["outputs"]["model_outputs"], self.ap, output_fig=False
)
- test_figures["{}-alignment".format(idx)] = plot_alignment(
- outputs_dict["outputs"]["alignments"], output_fig=False
- )
+ test_figures[f"{idx}-alignment"] = plot_alignment(outputs_dict["outputs"]["alignments"], output_fig=False)
return test_figures, test_audios
def on_init_start(self, trainer: Trainer) -> None:
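The duration fix-up shown in `format_batch` above removes surplus frames by decrementing the largest durations until they match the spectrogram length. A hypothetical standalone version, assuming the surplus never exceeds the number of durations (as in the original):

```python
import torch

def trim_durations(dur: torch.Tensor, target_len: int) -> torch.Tensor:
    # Drop one frame from each of the largest durations until they sum to target_len.
    extra_frames = int(dur.sum().item()) - target_len
    if extra_frames > 0:
        largest_idxs = torch.argsort(-dur)[:extra_frames]
        dur[largest_idxs] -= 1
    assert dur.sum() == target_len
    return dur

print(trim_durations(torch.tensor([3, 5, 4]), 10))  # tensor([3, 4, 3])
```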
diff --git a/TTS/vc/models/freevc.py b/TTS/vc/models/freevc.py
index 104ad9ae6c..59af40a836 100644
--- a/TTS/vc/models/freevc.py
+++ b/TTS/vc/models/freevc.py
@@ -1,5 +1,4 @@
import logging
-from typing import Optional, Tuple, Union
import librosa
import numpy as np
@@ -102,7 +101,7 @@ def __init__(
upsample_kernel_sizes,
gin_channels=0,
):
- super(Generator, self).__init__()
+ super().__init__()
self.num_kernels = len(resblock_kernel_sizes)
self.num_upsamples = len(upsample_rates)
self.conv_pre = Conv1d(initial_channel, upsample_initial_channel, 7, 1, padding=3)
@@ -165,7 +164,7 @@ def remove_weight_norm(self):
class MultiPeriodDiscriminator(torch.nn.Module):
def __init__(self, use_spectral_norm=False):
- super(MultiPeriodDiscriminator, self).__init__()
+ super().__init__()
periods = [2, 3, 5, 7, 11]
discs = [DiscriminatorS(use_spectral_norm=use_spectral_norm)]
@@ -190,7 +189,7 @@ def forward(self, y, y_hat):
class SpeakerEncoder(torch.nn.Module):
def __init__(self, mel_n_channels=80, model_num_layers=3, model_hidden_size=256, model_embedding_size=256):
- super(SpeakerEncoder, self).__init__()
+ super().__init__()
self.lstm = nn.LSTM(mel_n_channels, model_hidden_size, model_num_layers, batch_first=True)
self.linear = nn.Linear(model_hidden_size, model_embedding_size)
self.relu = nn.ReLU()
@@ -331,15 +330,15 @@ def forward(
self,
c: torch.Tensor,
spec: torch.Tensor,
- g: Optional[torch.Tensor] = None,
- mel: Optional[torch.Tensor] = None,
- c_lengths: Optional[torch.Tensor] = None,
- spec_lengths: Optional[torch.Tensor] = None,
- ) -> Tuple[
+ g: torch.Tensor | None = None,
+ mel: torch.Tensor | None = None,
+ c_lengths: torch.Tensor | None = None,
+ spec_lengths: torch.Tensor | None = None,
+ ) -> tuple[
torch.Tensor,
torch.Tensor,
torch.Tensor,
- Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor],
+ tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor],
]:
"""
Forward pass of the model.
@@ -431,7 +430,7 @@ def load_audio(self, wav):
return wav.float()
@torch.inference_mode()
- def voice_conversion(self, src: Union[str, torch.Tensor], tgt: list[Union[str, torch.Tensor]]):
+ def voice_conversion(self, src: str | torch.Tensor, tgt: list[str | torch.Tensor]):
"""
Voice conversion pass of the model.
diff --git a/TTS/vc/models/knnvc.py b/TTS/vc/models/knnvc.py
index 2f504704ef..c31f52e749 100644
--- a/TTS/vc/models/knnvc.py
+++ b/TTS/vc/models/knnvc.py
@@ -1,12 +1,11 @@
import logging
import os
-from typing import Any, Optional, Union
+from typing import Any, TypeAlias
import torch
import torch.nn.functional as F
import torchaudio
from coqpit import Coqpit
-from typing_extensions import TypeAlias
from TTS.vc.configs.knnvc_config import KNNVCConfig
from TTS.vc.layers.freevc.wavlm import get_wavlm
@@ -14,7 +13,7 @@
logger = logging.getLogger(__name__)
-PathOrTensor: TypeAlias = Union[str, os.PathLike[Any], torch.Tensor]
+PathOrTensor: TypeAlias = str | os.PathLike[Any] | torch.Tensor
class KNNVC(BaseVC):
@@ -74,7 +73,7 @@ def get_features(self, audio: PathOrTensor, vad_trigger_level=0) -> torch.Tensor
x, sr = torchaudio.load(audio, normalize=True)
if not sr == self.config.audio.sample_rate:
- logger.info(f"Resampling {sr} to {self.config.audio.sample_rate} in {audio}")
+ logger.info("Resampling %d to %d in %s", sr, self.config.audio.sample_rate, audio)
x = torchaudio.functional.resample(x, orig_freq=sr, new_freq=self.config.audio.sample_rate)
sr = self.config.audio.sample_rate
@@ -126,9 +125,9 @@ def match(
self,
query_seq: torch.Tensor,
matching_set: torch.Tensor,
- synth_set: Optional[torch.Tensor] = None,
- topk: Optional[int] = None,
- target_duration: Optional[float] = None,
+ synth_set: torch.Tensor | None = None,
+ topk: int | None = None,
+ target_duration: float | None = None,
) -> torch.Tensor:
"""Given `query_seq`, `matching_set`, and `synth_set` tensors of shape (N, dim), perform kNN regression matching
with k=`topk`.
@@ -162,7 +161,7 @@ def match(
out_feats = synth_set[best.indices].mean(dim=1)
return out_feats.unsqueeze(0)
- def load_checkpoint(self, vc_config: KNNVCConfig, _vc_checkpoint: Union[str, os.PathLike[Any]]) -> None:
+ def load_checkpoint(self, vc_config: KNNVCConfig, _vc_checkpoint: str | os.PathLike[Any]) -> None:
"""kNN-VC does not use checkpoints."""
def forward(self) -> None: ...
@@ -173,7 +172,7 @@ def voice_conversion(
self,
source: PathOrTensor,
target: list[PathOrTensor],
- topk: Optional[int] = None,
+ topk: int | None = None,
) -> torch.Tensor:
if not isinstance(target, list):
target = [target]
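With the minimum Python version raised to 3.10, `TypeAlias` is importable from `typing` directly and `|` unions are valid at runtime, so the `typing_extensions` import can go. A small sketch under that assumption (the helper is illustrative, not part of KNNVC):

```python
import os
from typing import Any, TypeAlias

import torch

PathOrTensor: TypeAlias = str | os.PathLike[Any] | torch.Tensor

def describe(x: PathOrTensor) -> str:
    # Illustrative only: distinguish in-memory tensors from on-disk paths.
    return "tensor" if isinstance(x, torch.Tensor) else f"path: {os.fspath(x)}"
```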
diff --git a/TTS/vc/models/openvoice.py b/TTS/vc/models/openvoice.py
index 3cb37e64b5..1049a580c7 100644
--- a/TTS/vc/models/openvoice.py
+++ b/TTS/vc/models/openvoice.py
@@ -1,8 +1,9 @@
import json
import logging
import os
+from collections.abc import Mapping
from pathlib import Path
-from typing import Any, Mapping, Optional, Union
+from typing import Any
import librosa
import numpy as np
@@ -117,7 +118,7 @@ class OpenVoice(BaseVC):
October 2023, serving as the backend of MyShell.
"""
- def __init__(self, config: Coqpit, speaker_manager: Optional[SpeakerManager] = None) -> None:
+ def __init__(self, config: Coqpit, speaker_manager: SpeakerManager | None = None) -> None:
super().__init__(config, None, speaker_manager, None)
self.init_multispeaker(config)
@@ -178,7 +179,7 @@ def __init__(self, config: Coqpit, speaker_manager: Optional[SpeakerManager] = N
def init_from_config(config: OpenVoiceConfig) -> "OpenVoice":
return OpenVoice(config)
- def init_multispeaker(self, config: Coqpit, data: Optional[list[Any]] = None) -> None:
+ def init_multispeaker(self, config: Coqpit, data: list[Any] | None = None) -> None:
"""Initialize multi-speaker modules of a model. A model can be trained either with a speaker embedding layer
or with external `d_vectors` computed from a speaker encoder model.
@@ -195,7 +196,7 @@ def init_multispeaker(self, config: Coqpit, data: Optional[list[Any]] = None) ->
def load_checkpoint(
self,
config: OpenVoiceConfig,
- checkpoint_path: Union[str, os.PathLike[Any]],
+ checkpoint_path: str | os.PathLike[Any],
eval: bool = False,
strict: bool = True,
cache: bool = False,
@@ -219,7 +220,7 @@ def train_step(self) -> None: ...
def eval_step(self) -> None: ...
@staticmethod
- def _set_x_lengths(x: torch.Tensor, aux_input: Mapping[str, Optional[torch.Tensor]]) -> torch.Tensor:
+ def _set_x_lengths(x: torch.Tensor, aux_input: Mapping[str, torch.Tensor | None]) -> torch.Tensor:
if "x_lengths" in aux_input and aux_input["x_lengths"] is not None:
return aux_input["x_lengths"]
return torch.tensor(x.shape[-1:]).to(x.device)
@@ -228,7 +229,7 @@ def _set_x_lengths(x: torch.Tensor, aux_input: Mapping[str, Optional[torch.Tenso
def inference(
self,
x: torch.Tensor,
- aux_input: Mapping[str, Optional[torch.Tensor]] = {"x_lengths": None, "g_src": None, "g_tgt": None},
+ aux_input: Mapping[str, torch.Tensor | None] = {"x_lengths": None, "g_src": None, "g_tgt": None},
) -> dict[str, torch.Tensor]:
"""
Inference pass of the model
@@ -267,7 +268,7 @@ def inference(
"z_hat": z_hat,
}
- def load_audio(self, wav: Union[str, npt.NDArray[np.float32], torch.Tensor, list[float]]) -> torch.Tensor:
+ def load_audio(self, wav: str | npt.NDArray[np.float32] | torch.Tensor | list[float]) -> torch.Tensor:
"""Read and format the input audio."""
if isinstance(wav, str):
out = torch.from_numpy(librosa.load(wav, sr=self.config.audio.input_sample_rate)[0])
@@ -279,7 +280,7 @@ def load_audio(self, wav: Union[str, npt.NDArray[np.float32], torch.Tensor, list
out = wav
return out.to(self.device).float()
- def extract_se(self, audio: Union[str, torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
+ def extract_se(self, audio: str | torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
y = self.load_audio(audio)
y = y.to(self.device)
y = y.unsqueeze(0)
@@ -296,9 +297,7 @@ def extract_se(self, audio: Union[str, torch.Tensor]) -> tuple[torch.Tensor, tor
return g, spec
@torch.inference_mode()
- def voice_conversion(
- self, src: Union[str, torch.Tensor], tgt: list[Union[str, torch.Tensor]]
- ) -> npt.NDArray[np.float32]:
+ def voice_conversion(self, src: str | torch.Tensor, tgt: list[str | torch.Tensor]) -> npt.NDArray[np.float32]:
"""
Voice conversion pass of the model.
diff --git a/TTS/vocoder/configs/multiband_melgan_config.py b/TTS/vocoder/configs/multiband_melgan_config.py
index 763113537f..2139f47b0e 100644
--- a/TTS/vocoder/configs/multiband_melgan_config.py
+++ b/TTS/vocoder/configs/multiband_melgan_config.py
@@ -121,7 +121,7 @@ class MultibandMelganConfig(BaseGANVocoderConfig):
pad_short: int = 2000
use_noise_augment: bool = False
use_cache: bool = True
- steps_to_start_discriminator: bool = 200000
+ steps_to_start_discriminator: int = 200000
# LOSS PARAMETERS - overrides
use_stft_loss: bool = True
diff --git a/TTS/vocoder/configs/shared_configs.py b/TTS/vocoder/configs/shared_configs.py
index a558cfcabb..548505a54d 100644
--- a/TTS/vocoder/configs/shared_configs.py
+++ b/TTS/vocoder/configs/shared_configs.py
@@ -168,7 +168,7 @@ class BaseGANVocoderConfig(BaseVocoderConfig):
target_loss: str = "loss_0" # loss value to pick the best model to save after each epoch
# optimizer
- grad_clip: float = field(default_factory=lambda: [5, 5])
+ grad_clip: float | list[float] = field(default_factory=lambda: [5, 5])
lr_gen: float = 0.0002 # Initial learning rate.
lr_disc: float = 0.0002 # Initial learning rate.
lr_scheduler_gen: str = "ExponentialLR" # one of the schedulers from https:#pytorch.org/docs/stable/optim.html
@@ -178,5 +178,5 @@ class BaseGANVocoderConfig(BaseVocoderConfig):
scheduler_after_epoch: bool = True
use_pqmf: bool = False # enable/disable using pqmf for multi-band training. (Multi-band MelGAN)
- steps_to_start_discriminator = 0 # start training the discriminator after this number of steps.
+ steps_to_start_discriminator: int = 0 # start training the discriminator after this number of steps.
diff_samples_for_G_and_D: bool = False # use different samples for G and D training steps.
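Two fixes above are worth spelling out: `grad_clip` is widened to `float | list[float]` so the annotation matches its list-valued default, and `steps_to_start_discriminator` gains an `int` annotation because, in a dataclass-based config, an unannotated assignment is only a class attribute and never becomes a field. A plain-dataclass sketch of that rule (Coqpit configs are assumed to behave the same way):

```python
from dataclasses import dataclass, fields

@dataclass
class Cfg:
    annotated: int = 0
    unannotated = 0  # plain class attribute: ignored by the dataclass machinery

print([f.name for f in fields(Cfg)])  # ['annotated']
```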
diff --git a/TTS/vocoder/configs/univnet_config.py b/TTS/vocoder/configs/univnet_config.py
index 67f324cfce..85662831ee 100644
--- a/TTS/vocoder/configs/univnet_config.py
+++ b/TTS/vocoder/configs/univnet_config.py
@@ -1,5 +1,4 @@
from dataclasses import dataclass, field
-from typing import Dict
from TTS.vocoder.configs.shared_configs import BaseGANVocoderConfig
@@ -96,7 +95,7 @@ class UnivnetConfig(BaseGANVocoderConfig):
# model specific params
discriminator_model: str = "univnet_discriminator"
generator_model: str = "univnet_generator"
- generator_model_params: Dict = field(
+ generator_model_params: dict = field(
default_factory=lambda: {
"in_channels": 64,
"out_channels": 1,
@@ -121,7 +120,7 @@ class UnivnetConfig(BaseGANVocoderConfig):
# loss weights - overrides
stft_loss_weight: float = 2.5
- stft_loss_params: Dict = field(
+ stft_loss_params: dict = field(
default_factory=lambda: {
"n_ffts": [1024, 2048, 512],
"hop_lengths": [120, 240, 50],
@@ -133,7 +132,7 @@ class UnivnetConfig(BaseGANVocoderConfig):
hinge_G_loss_weight: float = 0
feat_match_loss_weight: float = 0
l1_spec_loss_weight: float = 0
- l1_spec_loss_params: Dict = field(
+ l1_spec_loss_params: dict = field(
default_factory=lambda: {
"use_mel": True,
"sample_rate": 22050,
@@ -153,7 +152,7 @@ class UnivnetConfig(BaseGANVocoderConfig):
# lr_scheduler_gen_params: dict = field(default_factory=lambda: {"gamma": 0.999, "last_epoch": -1})
lr_scheduler_disc: str = None # one of the schedulers from https:#pytorch.org/docs/stable/optim.html
# lr_scheduler_disc_params: dict = field(default_factory=lambda: {"gamma": 0.999, "last_epoch": -1})
- optimizer_params: Dict = field(default_factory=lambda: {"betas": [0.5, 0.9], "weight_decay": 0.0})
+ optimizer_params: dict = field(default_factory=lambda: {"betas": [0.5, 0.9], "weight_decay": 0.0})
steps_to_start_discriminator: int = 200000
def __post_init__(self):
diff --git a/TTS/vocoder/datasets/__init__.py b/TTS/vocoder/datasets/__init__.py
index 04462817a8..cef6a50b05 100644
--- a/TTS/vocoder/datasets/__init__.py
+++ b/TTS/vocoder/datasets/__init__.py
@@ -1,5 +1,3 @@
-from typing import List
-
from coqpit import Coqpit
from torch.utils.data import Dataset
@@ -10,7 +8,7 @@
from TTS.vocoder.datasets.wavernn_dataset import WaveRNNDataset
-def setup_dataset(config: Coqpit, ap: AudioProcessor, is_eval: bool, data_items: List) -> Dataset:
+def setup_dataset(config: Coqpit, ap: AudioProcessor, is_eval: bool, data_items: list) -> Dataset:
if config.model.lower() in "gan":
dataset = GANDataset(
ap=ap,
diff --git a/TTS/vocoder/datasets/gan_dataset.py b/TTS/vocoder/datasets/gan_dataset.py
index 0806c0d496..076545f8a2 100644
--- a/TTS/vocoder/datasets/gan_dataset.py
+++ b/TTS/vocoder/datasets/gan_dataset.py
@@ -32,7 +32,7 @@ def __init__(
super().__init__()
self.ap = ap
self.item_list = items
- self.compute_feat = not isinstance(items[0], (tuple, list))
+ self.compute_feat = not isinstance(items[0], tuple | list)
self.seq_len = seq_len
self.hop_len = hop_len
self.pad_short = pad_short
@@ -128,9 +128,9 @@ def load_item(self, idx):
# correct the audio length wrt padding applied in stft
audio = np.pad(audio, (0, self.hop_len), mode="edge")
audio = audio[: mel.shape[-1] * self.hop_len]
- assert (
- mel.shape[-1] * self.hop_len == audio.shape[-1]
- ), f" [!] {mel.shape[-1] * self.hop_len} vs {audio.shape[-1]}"
+ assert mel.shape[-1] * self.hop_len == audio.shape[-1], (
+ f" [!] {mel.shape[-1] * self.hop_len} vs {audio.shape[-1]}"
+ )
audio = torch.from_numpy(audio).float().unsqueeze(0)
mel = torch.from_numpy(mel).float().squeeze(0)
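The `isinstance(items[0], tuple | list)` form relies on Python 3.10+, where PEP 604 unions are accepted by `isinstance` and `issubclass` at runtime; it is equivalent to the older tuple-of-types form:

```python
items = [("wav.npy", "mel.npy"), "raw.wav"]

print(isinstance(items[0], tuple | list))   # True (Python >= 3.10)
print(isinstance(items[1], tuple | list))   # False
print(isinstance(items[0], (tuple, list)))  # True, the pre-3.10 equivalent
```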
diff --git a/TTS/vocoder/datasets/wavegrad_dataset.py b/TTS/vocoder/datasets/wavegrad_dataset.py
index 6f34bccb7c..435330bebe 100644
--- a/TTS/vocoder/datasets/wavegrad_dataset.py
+++ b/TTS/vocoder/datasets/wavegrad_dataset.py
@@ -2,7 +2,6 @@
import os
import random
from multiprocessing import Manager
-from typing import List, Tuple
import numpy as np
import torch
@@ -65,7 +64,7 @@ def __getitem__(self, idx):
item = self.load_item(idx)
return item
- def load_test_samples(self, num_samples: int) -> List[Tuple]:
+ def load_test_samples(self, num_samples: int) -> list[tuple]:
"""Return test samples.
Args:
@@ -103,9 +102,9 @@ def load_item(self, idx):
audio = np.pad(
audio, (0, self.seq_len + self.pad_short - len(audio)), mode="constant", constant_values=0.0
)
- assert (
- audio.shape[-1] >= self.seq_len + self.pad_short
- ), f"{audio.shape[-1]} vs {self.seq_len + self.pad_short}"
+ assert audio.shape[-1] >= self.seq_len + self.pad_short, (
+ f"{audio.shape[-1]} vs {self.seq_len + self.pad_short}"
+ )
# correct the audio length wrt hop length
p = (audio.shape[-1] // self.hop_len + 1) * self.hop_len - audio.shape[-1]
diff --git a/TTS/vocoder/datasets/wavernn_dataset.py b/TTS/vocoder/datasets/wavernn_dataset.py
index 4c4f5c48df..ffb71177c5 100644
--- a/TTS/vocoder/datasets/wavernn_dataset.py
+++ b/TTS/vocoder/datasets/wavernn_dataset.py
@@ -18,7 +18,7 @@ class WaveRNNDataset(Dataset):
def __init__(self, ap, items, seq_len, hop_len, pad, mode, mulaw, is_training=True, return_segments=True):
super().__init__()
self.ap = ap
- self.compute_feat = not isinstance(items[0], (tuple, list))
+ self.compute_feat = not isinstance(items[0], tuple | list)
self.item_list = items
self.seq_len = seq_len
self.hop_len = hop_len
diff --git a/TTS/vocoder/layers/losses.py b/TTS/vocoder/layers/losses.py
index 8d4dd725ef..81a1f30884 100644
--- a/TTS/vocoder/layers/losses.py
+++ b/TTS/vocoder/layers/losses.py
@@ -1,5 +1,3 @@
-from typing import Dict, Union
-
import torch
from torch import nn
from torch.nn import functional as F
@@ -226,9 +224,9 @@ class GeneratorLoss(nn.Module):
def __init__(self, C):
super().__init__()
- assert not (
- C.use_mse_gan_loss and C.use_hinge_gan_loss
- ), " [!] Cannot use HingeGANLoss and MSEGANLoss together."
+ assert not (C.use_mse_gan_loss and C.use_hinge_gan_loss), (
+ " [!] Cannot use HingeGANLoss and MSEGANLoss together."
+ )
self.use_stft_loss = C.use_stft_loss if "use_stft_loss" in C else False
self.use_subband_stft_loss = C.use_subband_stft_loss if "use_subband_stft_loss" in C else False
@@ -313,9 +311,9 @@ class DiscriminatorLoss(nn.Module):
def __init__(self, C):
super().__init__()
- assert not (
- C.use_mse_gan_loss and C.use_hinge_gan_loss
- ), " [!] Cannot use HingeGANLoss and MSEGANLoss together."
+ assert not (C.use_mse_gan_loss and C.use_hinge_gan_loss), (
+ " [!] Cannot use HingeGANLoss and MSEGANLoss together."
+ )
self.use_mse_gan_loss = C.use_mse_gan_loss
self.use_hinge_gan_loss = C.use_hinge_gan_loss
@@ -352,7 +350,7 @@ def forward(self, scores_fake, scores_real):
class WaveRNNLoss(nn.Module):
- def __init__(self, wave_rnn_mode: Union[str, int]):
+ def __init__(self, wave_rnn_mode: str | int):
super().__init__()
if wave_rnn_mode == "mold":
self.loss_func = discretized_mix_logistic_loss
@@ -363,6 +361,6 @@ def __init__(self, wave_rnn_mode: Union[str, int]):
else:
raise ValueError(" [!] Unknown mode for Wavernn.")
- def forward(self, y_hat, y) -> Dict:
+ def forward(self, y_hat, y) -> dict:
loss = self.loss_func(y_hat, y)
return {"loss": loss}
diff --git a/TTS/vocoder/layers/lvc_block.py b/TTS/vocoder/layers/lvc_block.py
index 8913a1132e..ab1a56e7fc 100644
--- a/TTS/vocoder/layers/lvc_block.py
+++ b/TTS/vocoder/layers/lvc_block.py
@@ -175,9 +175,9 @@ def location_variable_convolution(x, kernel, bias, dilation, hop_size):
batch, _, in_length = x.shape
batch, _, out_channels, kernel_size, kernel_length = kernel.shape
- assert in_length == (
- kernel_length * hop_size
- ), f"length of (x, kernel) is not matched, {in_length} vs {kernel_length * hop_size}"
+ assert in_length == (kernel_length * hop_size), (
+ f"length of (x, kernel) is not matched, {in_length} vs {kernel_length * hop_size}"
+ )
padding = dilation * int((kernel_size - 1) / 2)
x = F.pad(x, (padding, padding), "constant", 0) # (batch, in_channels, in_length + 2*padding)
diff --git a/TTS/vocoder/layers/wavegrad.py b/TTS/vocoder/layers/wavegrad.py
index 9f1512c6d4..187e7062e2 100644
--- a/TTS/vocoder/layers/wavegrad.py
+++ b/TTS/vocoder/layers/wavegrad.py
@@ -74,7 +74,7 @@ def shif_and_scale(x, scale, shift):
class UBlock(nn.Module):
def __init__(self, input_size, hidden_size, factor, dilation):
super().__init__()
- assert isinstance(dilation, (list, tuple))
+ assert isinstance(dilation, list | tuple)
assert len(dilation) == 4
self.factor = factor
diff --git a/TTS/vocoder/models/gan.py b/TTS/vocoder/models/gan.py
index 7785d8011c..6abb2dc997 100644
--- a/TTS/vocoder/models/gan.py
+++ b/TTS/vocoder/models/gan.py
@@ -1,5 +1,4 @@
from inspect import signature
-from typing import Dict, List, Tuple
import numpy as np
import torch
@@ -65,7 +64,7 @@ def inference(self, x: torch.Tensor) -> torch.Tensor:
"""
return self.model_g.inference(x)
- def train_step(self, batch: Dict, criterion: Dict, optimizer_idx: int) -> Tuple[Dict, Dict]:
+ def train_step(self, batch: dict, criterion: dict, optimizer_idx: int) -> tuple[dict, dict]:
"""Compute model outputs and the loss values. `optimizer_idx` selects the generator or the discriminator for
network on the current pass.
@@ -185,7 +184,7 @@ def train_step(self, batch: Dict, criterion: Dict, optimizer_idx: int) -> Tuple[
outputs = {"model_outputs": self.y_hat_g}
return outputs, loss_dict
- def _log(self, name: str, ap: AudioProcessor, batch: Dict, outputs: Dict) -> Tuple[Dict, Dict]:
+ def _log(self, name: str, ap: AudioProcessor, batch: dict, outputs: dict) -> tuple[dict, dict]:
"""Logging shared by the training and evaluation.
Args:
@@ -205,22 +204,32 @@ def _log(self, name: str, ap: AudioProcessor, batch: Dict, outputs: Dict) -> Tup
return figures, audios
def train_log(
- self, batch: Dict, outputs: Dict, logger: "Logger", assets: Dict, steps: int # pylint: disable=unused-argument
- ) -> Tuple[Dict, np.ndarray]:
+ self,
+ batch: dict,
+ outputs: dict,
+ logger: "Logger",
+ assets: dict,
+ steps: int, # pylint: disable=unused-argument
+ ) -> tuple[dict, np.ndarray]:
"""Call `_log()` for training."""
figures, audios = self._log("eval", self.ap, batch, outputs)
logger.eval_figures(steps, figures)
logger.eval_audios(steps, audios, self.ap.sample_rate)
@torch.inference_mode()
- def eval_step(self, batch: Dict, criterion: nn.Module, optimizer_idx: int) -> Tuple[Dict, Dict]:
+ def eval_step(self, batch: dict, criterion: nn.Module, optimizer_idx: int) -> tuple[dict, dict]:
"""Call `train_step()` with `no_grad()`"""
self.train_disc = True # Avoid a bug in the Training with the missing discriminator loss
return self.train_step(batch, criterion, optimizer_idx)
def eval_log(
- self, batch: Dict, outputs: Dict, logger: "Logger", assets: Dict, steps: int # pylint: disable=unused-argument
- ) -> Tuple[Dict, np.ndarray]:
+ self,
+ batch: dict,
+ outputs: dict,
+ logger: "Logger",
+ assets: dict,
+ steps: int, # pylint: disable=unused-argument
+ ) -> tuple[dict, np.ndarray]:
"""Call `_log()` for evaluation."""
figures, audios = self._log("eval", self.ap, batch, outputs)
logger.eval_figures(steps, figures)
@@ -259,7 +268,7 @@ def on_train_step_start(self, trainer) -> None:
"""
self.train_disc = trainer.total_steps_done >= self.config.steps_to_start_discriminator
- def get_optimizer(self) -> List:
+ def get_optimizer(self) -> list:
"""Initiate and return the GAN optimizers based on the config parameters.
It returns 2 optimizers in a list. The first one is for the generator and the second one is for the discriminator.
@@ -275,7 +284,7 @@ def get_optimizer(self) -> List:
)
return [optimizer2, optimizer1]
- def get_lr(self) -> List:
+ def get_lr(self) -> list:
"""Set the initial learning rates for each optimizer.
Returns:
@@ -283,7 +292,7 @@ def get_lr(self) -> List:
"""
return [self.config.lr_disc, self.config.lr_gen]
- def get_scheduler(self, optimizer) -> List:
+ def get_scheduler(self, optimizer) -> list:
"""Set the schedulers for each optimizer.
Args:
@@ -297,7 +306,7 @@ def get_scheduler(self, optimizer) -> List:
return [scheduler2, scheduler1]
@staticmethod
- def format_batch(batch: List) -> Dict:
+ def format_batch(batch: list) -> dict:
"""Format the batch for training.
Args:
@@ -316,12 +325,12 @@ def format_batch(batch: List) -> Dict:
def get_data_loader( # pylint: disable=no-self-use, unused-argument
self,
config: Coqpit,
- assets: Dict,
+ assets: dict,
is_eval: True,
- samples: List,
+ samples: list,
verbose: bool,
num_gpus: int,
- rank: int = None, # pylint: disable=unused-argument
+ rank: int | None = None, # pylint: disable=unused-argument
):
"""Initiate and return the GAN dataloader.
diff --git a/TTS/vocoder/models/hifigan_generator.py b/TTS/vocoder/models/hifigan_generator.py
index b2100c55b1..308b12ab56 100644
--- a/TTS/vocoder/models/hifigan_generator.py
+++ b/TTS/vocoder/models/hifigan_generator.py
@@ -313,9 +313,7 @@ def remove_weight_norm(self):
remove_parametrizations(self.conv_pre, "weight")
remove_parametrizations(self.conv_post, "weight")
- def load_checkpoint(
- self, config, checkpoint_path, eval=False, cache=False
- ): # pylint: disable=unused-argument, redefined-builtin
+ def load_checkpoint(self, config, checkpoint_path, eval=False, cache=False): # pylint: disable=unused-argument, redefined-builtin
state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"), cache=cache)
self.load_state_dict(state["model"])
if eval:
diff --git a/TTS/vocoder/models/melgan_generator.py b/TTS/vocoder/models/melgan_generator.py
index 03c971afa4..53ed700755 100644
--- a/TTS/vocoder/models/melgan_generator.py
+++ b/TTS/vocoder/models/melgan_generator.py
@@ -84,9 +84,7 @@ def remove_weight_norm(self):
except ValueError:
layer.remove_weight_norm()
- def load_checkpoint(
- self, config, checkpoint_path, eval=False, cache=False
- ): # pylint: disable=unused-argument, redefined-builtin
+ def load_checkpoint(self, config, checkpoint_path, eval=False, cache=False): # pylint: disable=unused-argument, redefined-builtin
state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"), cache=cache)
self.load_state_dict(state["model"])
if eval:
diff --git a/TTS/vocoder/models/parallel_wavegan_discriminator.py b/TTS/vocoder/models/parallel_wavegan_discriminator.py
index 211d45d91c..02ad60e0ff 100644
--- a/TTS/vocoder/models/parallel_wavegan_discriminator.py
+++ b/TTS/vocoder/models/parallel_wavegan_discriminator.py
@@ -71,7 +71,7 @@ def forward(self, x):
def apply_weight_norm(self):
def _apply_weight_norm(m):
- if isinstance(m, (torch.nn.Conv1d, torch.nn.Conv2d)):
+ if isinstance(m, torch.nn.Conv1d | torch.nn.Conv2d):
torch.nn.utils.parametrizations.weight_norm(m)
self.apply(_apply_weight_norm)
@@ -174,7 +174,7 @@ def forward(self, x):
def apply_weight_norm(self):
def _apply_weight_norm(m):
- if isinstance(m, (torch.nn.Conv1d, torch.nn.Conv2d)):
+ if isinstance(m, torch.nn.Conv1d | torch.nn.Conv2d):
torch.nn.utils.parametrizations.weight_norm(m)
self.apply(_apply_weight_norm)
diff --git a/TTS/vocoder/models/parallel_wavegan_generator.py b/TTS/vocoder/models/parallel_wavegan_generator.py
index 0659a00cc1..71b38d4c0d 100644
--- a/TTS/vocoder/models/parallel_wavegan_generator.py
+++ b/TTS/vocoder/models/parallel_wavegan_generator.py
@@ -108,9 +108,9 @@ def forward(self, c):
# perform upsampling
if c is not None and self.upsample_net is not None:
c = self.upsample_net(c)
- assert (
- c.shape[-1] == x.shape[-1]
- ), f" [!] Upsampling scale does not match the expected output. {c.shape} vs {x.shape}"
+ assert c.shape[-1] == x.shape[-1], (
+ f" [!] Upsampling scale does not match the expected output. {c.shape} vs {x.shape}"
+ )
# encode to hidden representation
x = self.first_conv(x)
@@ -145,7 +145,7 @@ def _remove_weight_norm(m):
def apply_weight_norm(self):
def _apply_weight_norm(m):
- if isinstance(m, (torch.nn.Conv1d, torch.nn.Conv2d)):
+ if isinstance(m, torch.nn.Conv1d | torch.nn.Conv2d):
torch.nn.utils.parametrizations.weight_norm(m)
logger.info("Weight norm is applied to %s", m)
@@ -155,9 +155,7 @@ def _apply_weight_norm(m):
def receptive_field_size(self):
return _get_receptive_field_size(self.layers, self.stacks, self.kernel_size)
- def load_checkpoint(
- self, config, checkpoint_path, eval=False, cache=False
- ): # pylint: disable=unused-argument, redefined-builtin
+ def load_checkpoint(self, config, checkpoint_path, eval=False, cache=False): # pylint: disable=unused-argument, redefined-builtin
state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"), cache=cache)
self.load_state_dict(state["model"])
if eval:
diff --git a/TTS/vocoder/models/univnet_generator.py b/TTS/vocoder/models/univnet_generator.py
index 19f5648f4d..d991941441 100644
--- a/TTS/vocoder/models/univnet_generator.py
+++ b/TTS/vocoder/models/univnet_generator.py
@@ -1,5 +1,4 @@
import logging
-from typing import List
import numpy as np
import torch
@@ -21,7 +20,7 @@ def __init__(
out_channels: int,
hidden_channels: int,
cond_channels: int,
- upsample_factors: List[int],
+ upsample_factors: list[int],
lvc_layers_each_block: int,
lvc_kernel_size: int,
kpnet_hidden_channels: int,
@@ -128,7 +127,7 @@ def apply_weight_norm(self):
"""Apply weight normalization module from all of the layers."""
def _apply_weight_norm(m):
- if isinstance(m, (torch.nn.Conv1d, torch.nn.Conv2d)):
+ if isinstance(m, torch.nn.Conv1d | torch.nn.Conv2d):
torch.nn.utils.parametrizations.weight_norm(m)
logger.info("Weight norm is applied to %s", m)
diff --git a/TTS/vocoder/models/wavegrad.py b/TTS/vocoder/models/wavegrad.py
index d756f956dd..5aa8ce5bb9 100644
--- a/TTS/vocoder/models/wavegrad.py
+++ b/TTS/vocoder/models/wavegrad.py
@@ -1,5 +1,4 @@
from dataclasses import dataclass, field
-from typing import Dict, List, Tuple
import numpy as np
import torch
@@ -25,10 +24,10 @@ class WavegradArgs(Coqpit):
use_weight_norm: bool = False
y_conv_channels: int = 32
x_conv_channels: int = 768
- dblock_out_channels: List[int] = field(default_factory=lambda: [128, 128, 256, 512])
- ublock_out_channels: List[int] = field(default_factory=lambda: [512, 512, 256, 128, 128])
- upsample_factors: List[int] = field(default_factory=lambda: [4, 4, 4, 2, 2])
- upsample_dilations: List[List[int]] = field(
+ dblock_out_channels: list[int] = field(default_factory=lambda: [128, 128, 256, 512])
+ ublock_out_channels: list[int] = field(default_factory=lambda: [512, 512, 256, 128, 128])
+ upsample_factors: list[int] = field(default_factory=lambda: [4, 4, 4, 2, 2])
+ upsample_dilations: list[list[int]] = field(
default_factory=lambda: [[1, 2, 1, 2], [1, 2, 1, 2], [1, 2, 4, 8], [1, 2, 4, 8], [1, 2, 4, 8]]
)
@@ -218,9 +217,7 @@ def apply_weight_norm(self):
self.out_conv = weight_norm(self.out_conv)
self.y_conv = weight_norm(self.y_conv)
- def load_checkpoint(
- self, config, checkpoint_path, eval=False, cache=False
- ): # pylint: disable=unused-argument, redefined-builtin
+ def load_checkpoint(self, config, checkpoint_path, eval=False, cache=False): # pylint: disable=unused-argument, redefined-builtin
state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"), cache=cache)
self.load_state_dict(state["model"])
if eval:
@@ -242,7 +239,7 @@ def load_checkpoint(
)
self.compute_noise_level(betas)
- def train_step(self, batch: Dict, criterion: Dict) -> Tuple[Dict, Dict]:
+ def train_step(self, batch: dict, criterion: dict) -> tuple[dict, dict]:
# format data
x = batch["input"]
y = batch["waveform"]
@@ -258,20 +255,30 @@ def train_step(self, batch: Dict, criterion: Dict) -> Tuple[Dict, Dict]:
return {"model_output": noise_hat}, {"loss": loss}
def train_log( # pylint: disable=no-self-use
- self, batch: Dict, outputs: Dict, logger: "Logger", assets: Dict, steps: int # pylint: disable=unused-argument
- ) -> Tuple[Dict, np.ndarray]:
+ self,
+ batch: dict,
+ outputs: dict,
+ logger: "Logger",
+ assets: dict,
+ steps: int, # pylint: disable=unused-argument
+ ) -> tuple[dict, np.ndarray]:
pass
@torch.inference_mode()
- def eval_step(self, batch: Dict, criterion: nn.Module) -> Tuple[Dict, Dict]:
+ def eval_step(self, batch: dict, criterion: nn.Module) -> tuple[dict, dict]:
return self.train_step(batch, criterion)
def eval_log( # pylint: disable=no-self-use
- self, batch: Dict, outputs: Dict, logger: "Logger", assets: Dict, steps: int # pylint: disable=unused-argument
+ self,
+ batch: dict,
+ outputs: dict,
+ logger: "Logger",
+ assets: dict,
+ steps: int, # pylint: disable=unused-argument
) -> None:
pass
- def test(self, assets: Dict, test_loader: "DataLoader", outputs=None): # pylint: disable=unused-argument
+ def test(self, assets: dict, test_loader: "DataLoader", outputs=None): # pylint: disable=unused-argument
# setup noise schedule and inference
ap = assets["audio_processor"]
noise_schedule = self.config["test_noise_schedule"]
@@ -302,13 +309,22 @@ def get_criterion():
return torch.nn.L1Loss()
@staticmethod
- def format_batch(batch: Dict) -> Dict:
+ def format_batch(batch: dict) -> dict:
# return a whole audio segment
m, y = batch[0], batch[1]
y = y.unsqueeze(1)
return {"input": m, "waveform": y}
- def get_data_loader(self, config: Coqpit, assets: Dict, is_eval: True, samples: List, verbose: bool, num_gpus: int):
+ def get_data_loader(
+ self,
+ config: Coqpit,
+ assets: dict,
+ is_eval: True,
+ samples: list,
+ verbose: bool,
+ num_gpus: int,
+ rank: int | None = None,
+ ):
ap = assets["audio_processor"]
dataset = WaveGradDataset(
ap=ap,
diff --git a/TTS/vocoder/models/wavernn.py b/TTS/vocoder/models/wavernn.py
index 4ece55af62..fb95d47589 100644
--- a/TTS/vocoder/models/wavernn.py
+++ b/TTS/vocoder/models/wavernn.py
@@ -1,7 +1,6 @@
import sys
import time
from dataclasses import dataclass, field
-from typing import Dict, List, Tuple
import numpy as np
import torch
@@ -171,7 +170,7 @@ class WavernnArgs(Coqpit):
num_res_blocks: int = 10
use_aux_net: bool = True
use_upsample_net: bool = True
- upsample_factors: List[int] = field(default_factory=lambda: [4, 8, 8])
+ upsample_factors: list[int] = field(default_factory=lambda: [4, 8, 8])
mode: str = "mold" # mold [string], gauss [string], bits [int]
mulaw: bool = True # apply mulaw if mode is bits
pad: int = 2
@@ -226,9 +225,9 @@ class of models has however remained an elusive problem. With a focus on text-to
self.aux_dims = self.args.res_out_dims // 4
if self.args.use_upsample_net:
- assert (
- np.cumprod(self.args.upsample_factors)[-1] == config.audio.hop_length
- ), " [!] upsample scales needs to be equal to hop_length"
+ assert np.cumprod(self.args.upsample_factors)[-1] == config.audio.hop_length, (
+ " [!] upsample scales needs to be equal to hop_length"
+ )
self.upsample = UpsampleNetwork(
self.args.feat_dims,
self.args.upsample_factors,
@@ -528,16 +527,14 @@ def xfade_and_unfold(y, target, overlap):
return unfolded
- def load_checkpoint(
- self, config, checkpoint_path, eval=False, cache=False
- ): # pylint: disable=unused-argument, redefined-builtin
+ def load_checkpoint(self, config, checkpoint_path, eval=False, cache=False): # pylint: disable=unused-argument, redefined-builtin
state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"), cache=cache)
self.load_state_dict(state["model"])
if eval:
self.eval()
assert not self.training
- def train_step(self, batch: Dict, criterion: Dict) -> Tuple[Dict, Dict]:
+ def train_step(self, batch: dict, criterion: dict) -> tuple[dict, dict]:
mels = batch["input"]
waveform = batch["waveform"]
waveform_coarse = batch["waveform_coarse"]
@@ -552,13 +549,16 @@ def train_step(self, batch: Dict, criterion: Dict) -> Tuple[Dict, Dict]:
loss_dict = criterion(y_hat, waveform_coarse)
return {"model_output": y_hat}, loss_dict
- def eval_step(self, batch: Dict, criterion: Dict) -> Tuple[Dict, Dict]:
+ def eval_step(self, batch: dict, criterion: dict) -> tuple[dict, dict]:
return self.train_step(batch, criterion)
@torch.no_grad()
def test(
- self, assets: Dict, test_loader: "DataLoader", output: Dict # pylint: disable=unused-argument
- ) -> Tuple[Dict, Dict]:
+ self,
+ assets: dict,
+ test_loader: "DataLoader",
+ output: dict, # pylint: disable=unused-argument
+ ) -> tuple[dict, dict]:
ap = self.ap
figures = {}
audios = {}
@@ -579,14 +579,18 @@ def test(
return figures, audios
def test_log(
- self, outputs: Dict, logger: "Logger", assets: Dict, steps: int # pylint: disable=unused-argument
- ) -> Tuple[Dict, np.ndarray]:
+ self,
+ outputs: dict,
+ logger: "Logger",
+ assets: dict,
+ steps: int, # pylint: disable=unused-argument
+ ) -> tuple[dict, np.ndarray]:
figures, audios = outputs
logger.eval_figures(steps, figures)
logger.eval_audios(steps, audios, self.ap.sample_rate)
@staticmethod
- def format_batch(batch: Dict) -> Dict:
+ def format_batch(batch: dict) -> dict:
waveform = batch[0]
mels = batch[1]
waveform_coarse = batch[2]
@@ -595,11 +599,12 @@ def format_batch(batch: Dict) -> Dict:
def get_data_loader( # pylint: disable=no-self-use
self,
config: Coqpit,
- assets: Dict,
+ assets: dict,
is_eval: True,
- samples: List,
+ samples: list,
verbose: bool,
num_gpus: int,
+ rank: int | None = None,
):
ap = self.ap
dataset = WaveRNNDataset(
diff --git a/TTS/vocoder/utils/distribution.py b/TTS/vocoder/utils/distribution.py
index fe706ba9ff..bef68e5564 100644
--- a/TTS/vocoder/utils/distribution.py
+++ b/TTS/vocoder/utils/distribution.py
@@ -12,7 +12,7 @@ def gaussian_loss(y_hat, y, log_std_min=-7.0):
mean = y_hat[:, :, :1]
log_std = torch.clamp(y_hat[:, :, 1:], min=log_std_min)
# TODO: replace with pytorch dist
- log_probs = -0.5 * (-math.log(2.0 * math.pi) - 2.0 * log_std - torch.pow(y - mean, 2) * torch.exp((-2.0 * log_std)))
+ log_probs = -0.5 * (-math.log(2.0 * math.pi) - 2.0 * log_std - torch.pow(y - mean, 2) * torch.exp(-2.0 * log_std))
return log_probs.squeeze().mean()
diff --git a/TTS/vocoder/utils/generic_utils.py b/TTS/vocoder/utils/generic_utils.py
index ac797d97f7..2823d206a0 100644
--- a/TTS/vocoder/utils/generic_utils.py
+++ b/TTS/vocoder/utils/generic_utils.py
@@ -1,5 +1,4 @@
import logging
-from typing import Dict
import numpy as np
import torch
@@ -32,7 +31,7 @@ def interpolate_vocoder_input(scale_factor, spec):
return spec
-def plot_results(y_hat: torch.tensor, y: torch.tensor, ap: AudioProcessor, name_prefix: str = None) -> Dict:
+def plot_results(y_hat: torch.tensor, y: torch.tensor, ap: AudioProcessor, name_prefix: str = None) -> dict:
"""Plot the predicted and the real waveform and their spectrograms.
Args:
diff --git a/notebooks/dataset_analysis/analyze.py b/notebooks/dataset_analysis/analyze.py
index 4855886efd..44bf25c071 100644
--- a/notebooks/dataset_analysis/analyze.py
+++ b/notebooks/dataset_analysis/analyze.py
@@ -43,7 +43,7 @@ def process_meta_data(path):
meta_data = {}
# load meta data
- with open(path, "r", encoding="utf-8") as f:
+ with open(path, encoding="utf-8") as f:
data = csv.reader(f, delimiter="|")
for row in data:
frames = int(row[2])
@@ -58,7 +58,7 @@ def process_meta_data(path):
"utt": utt,
"frames": frames,
"audio_len": audio_len,
- "row": "{}|{}|{}|{}".format(row[0], row[1], row[2], row[3]),
+ "row": f"{row[0]}|{row[1]}|{row[2]}|{row[3]}",
}
)
@@ -156,7 +156,7 @@ def plot_phonemes(train_path, cmu_dict_path, save_path):
phonemes = {}
- with open(train_path, "r", encoding="utf-8") as f:
+ with open(train_path, encoding="utf-8") as f:
data = csv.reader(f, delimiter="|")
phonemes["None"] = 0
for row in data:
diff --git a/pyproject.toml b/pyproject.toml
index 4b87a10b20..83473f411d 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -25,10 +25,10 @@ build-backend = "hatchling.build"
[project]
name = "coqui-tts"
-version = "0.25.3"
+version = "0.26.0"
description = "Deep learning for Text to Speech."
readme = "README.md"
-requires-python = ">=3.9, <3.13"
+requires-python = ">=3.10, <3.13"
license = {text = "MPL-2.0"}
authors = [
{name = "Eren Gölge", email = "egolge@coqui.ai"}
@@ -39,7 +39,6 @@ maintainers = [
classifiers = [
"Programming Language :: Python",
"Programming Language :: Python :: 3",
- "Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
@@ -77,18 +76,18 @@ dependencies = [
"matplotlib>=3.7.0",
# Coqui stack
"coqui-tts-trainer>=0.2.0,<0.3.0",
- "coqpit-config>=0.1.1,<0.2.0",
+ "coqpit-config>=0.2.0,<0.3.0",
"monotonic-alignment-search>=0.1.0",
# Gruut + supported languages
"gruut[de,es,fr]>=2.4.0",
# Tortoise
"einops>=0.6.0",
- "transformers>=4.43.0,<=4.46.2",
+ "transformers>=4.47.0",
# Bark
"encodec>=0.1.1",
# XTTS
"num2words>=0.5.14",
- "spacy[ja]>=3,<3.8",
+ "spacy[ja]>=3.2,<3.8",
]
[project.optional-dependencies]
@@ -116,7 +115,7 @@ ko = [
]
# Japanese
ja = [
- "mecab-python3>=1.0.2",
+ "mecab-python3>=1.0.6",
"unidic-lite==1.0.8",
"cutlet>=0.2.0",
]
@@ -136,11 +135,10 @@ all = [
[dependency-groups]
dev = [
- "black==24.2.0",
"coverage[toml]>=7",
- "pre-commit>=3",
+ "pre-commit>=4",
"pytest>=8",
- "ruff==0.7.0",
+ "ruff==0.9.1",
]
# Dependencies for building the documentation
docs = [
@@ -192,6 +190,7 @@ lint.extend-select = [
"F704", # yield-outside-function
"F706", # return-outside-function
"F841", # unused-variable
+ "G004", # no f-string in logging
"I", # import sorting
"PIE790", # unnecessary-pass
"PLC",
@@ -201,6 +200,7 @@ lint.extend-select = [
"PLR0911", # too-many-return-statements
"PLR1711", # useless-return
"PLW",
+ "UP", # pyupgrade
"W291", # trailing-whitespace
"NPY201", # NumPy 2.0 deprecation
]
@@ -231,10 +231,6 @@ max-returns = 7
"E402", # module level import not at top of file
]
-[tool.black]
-line-length = 120
-target-version = ['py39']
-
[tool.coverage.report]
skip_covered = true
skip_empty = true
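The new `G004` and `UP` ruff rules explain the logging and typing rewrites earlier in the diff: f-strings in logging calls are flagged because they are formatted even when the message is filtered out, while %-style arguments are interpolated lazily. A minimal illustration:

```python
import logging

logging.basicConfig(level=logging.WARNING)
logger = logging.getLogger(__name__)

sr, target = 16000, 22050
# Flagged by G004: the f-string is built even though INFO records are dropped here.
# logger.info(f"Resampling {sr} to {target}")
# Preferred: arguments are only interpolated if the record is actually emitted.
logger.info("Resampling %d to %d", sr, target)
```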
diff --git a/tests/__init__.py b/tests/__init__.py
index 8108bdeb50..0ee20a92df 100644
--- a/tests/__init__.py
+++ b/tests/__init__.py
@@ -1,5 +1,6 @@
import os
-from typing import Callable, Optional
+from collections.abc import Callable
+from typing import Optional
import pytest
from trainer.generic_utils import get_cuda
@@ -41,12 +42,7 @@ def get_tests_output_path():
return path
-def run_cli(command):
- exit_status = os.system(command)
- assert exit_status == 0, f" [!] command `{command}` failed."
-
-
-def run_main(main_func: Callable, args: Optional[list[str]] = None, expected_code: int = 0):
+def run_main(main_func: Callable, args: list[str] | None = None, expected_code: int = 0):
with pytest.raises(SystemExit) as exc_info:
main_func(args)
assert exc_info.value.code == expected_code
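`run_main` replaces the old `os.system`-based `run_cli`: it calls the entry point in-process and asserts on the `SystemExit` code, which keeps coverage measurement and debugging inside the test process. A hedged usage sketch with a hypothetical stand-in entry point:

```python
import argparse
import sys

from tests import run_main  # assumes the repository's tests package is importable

def demo_main(args=None):
    # Hypothetical stand-in for an entry point such as TTS.bin.synthesize.main.
    parser = argparse.ArgumentParser()
    parser.add_argument("--text", required=True)
    parser.parse_args(args)
    sys.exit(0)

run_main(demo_main, ["--text", "Hello."])  # exit code 0, the default expectation
run_main(demo_main, [], expected_code=2)   # argparse errors exit with code 2
```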
diff --git a/tests/data_tests/test_loader.py b/tests/data_tests/test_loader.py
index f260af161e..975281c549 100644
--- a/tests/data_tests/test_loader.py
+++ b/tests/data_tests/test_loader.py
@@ -51,7 +51,7 @@
if not os.path.exists(c.data_path):
DATA_EXIST = False
-print(" > Dynamic data loader test: {}".format(DATA_EXIST))
+print(f" > Dynamic data loader test: {DATA_EXIST}")
def _create_dataloader(batch_size, r, bgs, dataset_config, start_by_longest=False, preprocess_samples=False):
diff --git a/tests/integration/__init__.py b/tests/integration/__init__.py
new file mode 100644
index 0000000000..bd872c5b44
--- /dev/null
+++ b/tests/integration/__init__.py
@@ -0,0 +1,128 @@
+import json
+import shutil
+from pathlib import Path
+from typing import Any, TypeVar, Union
+
+import torch
+from trainer.io import get_last_checkpoint
+
+from tests import run_main
+from TTS.bin.synthesize import main as synthesize
+from TTS.bin.train_tts import main as train_tts
+from TTS.tts.configs.shared_configs import BaseTTSConfig
+from TTS.vc.configs.shared_configs import BaseVCConfig
+
+TEST_TTS_CONFIG = {
+ "batch_size": 8,
+ "eval_batch_size": 8,
+ "num_loader_workers": 0,
+ "num_eval_loader_workers": 0,
+ "text_cleaner": "english_cleaners",
+ "use_phonemes": True,
+ "phoneme_language": "en-us",
+ "run_eval": True,
+ "test_delay_epochs": -1,
+ "epochs": 1,
+ "print_step": 1,
+ "print_eval": True,
+ "test_sentences": ["Be a voice, not an echo."],
+}
+
+TEST_VC_CONFIG = {
+ "batch_size": 8,
+ "eval_batch_size": 8,
+ "num_loader_workers": 0,
+ "num_eval_loader_workers": 0,
+ "run_eval": True,
+ "test_delay_epochs": -1,
+ "epochs": 1,
+ "seq_len": 8192,
+ "eval_split_size": 1,
+ "print_step": 1,
+ "print_eval": True,
+ "data_path": "tests/data/ljspeech",
+}
+
+Config = TypeVar("Config", BaseTTSConfig, BaseVCConfig)
+
+
+def create_config(config_class: type[Config], **overrides: Any) -> Config:
+ base_config = TEST_TTS_CONFIG if issubclass(config_class, BaseTTSConfig) else TEST_VC_CONFIG
+ params = {**base_config, **overrides}
+ return config_class(**params)
+
+
+def run_tts_train(tmp_path: Path, config: BaseTTSConfig):
+ config_path = tmp_path / "test_model_config.json"
+ output_path = tmp_path / "train_outputs"
+
+ # For NeuralHMM and Overflow
+ parameter_path = tmp_path / "lj_parameters.pt"
+ torch.save({"mean": -5.5138, "std": 2.0636, "init_transition_prob": 0.3212}, parameter_path)
+ config.mel_statistics_parameter_path = parameter_path
+
+ config.audio.do_trim_silence = True
+ config.audio.trim_db = 60
+ config.save_json(config_path)
+
+ # train the model for one epoch
+ is_multi_speaker = config.use_speaker_embedding or config.use_d_vector_file
+ formatter = "ljspeech_test" if is_multi_speaker else "ljspeech"
+ command_train = [
+ "--config_path",
+ str(config_path),
+ "--coqpit.output_path",
+ str(output_path),
+ "--coqpit.phoneme_cache_path",
+ str(output_path / "phoneme_cache"),
+ "--coqpit.datasets.0.formatter",
+ formatter,
+ "--coqpit.datasets.0.meta_file_train",
+ "metadata.csv",
+ "--coqpit.datasets.0.meta_file_val",
+ "metadata.csv",
+ "--coqpit.datasets.0.path",
+ "tests/data/ljspeech",
+ "--coqpit.test_delay_epochs",
+ "0",
+ "--coqpit.datasets.0.meta_file_attn_mask",
+ "tests/data/ljspeech/metadata_attn_mask.txt",
+ ]
+ run_main(train_tts, command_train)
+
+ # Find latest folder
+ continue_path = max(output_path.iterdir(), key=lambda p: p.stat().st_mtime)
+
+ # Inference using TTS API
+ continue_config_path = continue_path / "config.json"
+ continue_restore_path, _ = get_last_checkpoint(continue_path)
+ out_wav_path = tmp_path / "output.wav"
+
+ # Check integrity of the config
+ with continue_config_path.open() as f:
+ config_loaded = json.load(f)
+ assert config_loaded["characters"] is not None
+ assert config_loaded["output_path"] in str(continue_path)
+ assert config_loaded["test_delay_epochs"] == 0
+
+ inference_command = [
+ "--text",
+ "This is an example for the tests.",
+ "--config_path",
+ str(continue_config_path),
+ "--model_path",
+ str(continue_restore_path),
+ "--out_path",
+ str(out_wav_path),
+ ]
+ if config.use_speaker_embedding:
+ continue_speakers_path = continue_path / "speakers.json"
+ elif config.use_d_vector_file:
+ continue_speakers_path = config.d_vector_file
+ if is_multi_speaker:
+ inference_command.extend(["--speaker_idx", "ljspeech-1", "--speakers_file_path", str(continue_speakers_path)])
+ run_main(synthesize, inference_command)
+
+ # restore the model and continue training for one more epoch
+ run_main(train_tts, ["--continue_path", str(continue_path)])
+ shutil.rmtree(tmp_path)
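`create_config` merges the shared defaults with per-test overrides, choosing `TEST_TTS_CONFIG` or `TEST_VC_CONFIG` based on the config class. A short usage sketch, constructing a config without running any training:

```python
from tests.integration import TEST_TTS_CONFIG, create_config
from TTS.tts.configs.glow_tts_config import GlowTTSConfig

cfg = create_config(GlowTTSConfig, batch_size=2, data_dep_init_steps=1)
assert cfg.epochs == TEST_TTS_CONFIG["epochs"]  # shared default carried over
assert cfg.batch_size == 2                      # per-test override wins
```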
diff --git a/tests/aux_tests/test_speaker_encoder_train.py b/tests/integration/test_speaker_encoder_train.py
similarity index 68%
rename from tests/aux_tests/test_speaker_encoder_train.py
rename to tests/integration/test_speaker_encoder_train.py
index 0e15db2ab0..ce817680b7 100644
--- a/tests/aux_tests/test_speaker_encoder_train.py
+++ b/tests/integration/test_speaker_encoder_train.py
@@ -1,6 +1,7 @@
import shutil
-from tests import get_device_id, run_cli
+from tests import run_main
+from TTS.bin.train_encoder import main
from TTS.config.shared_configs import BaseAudioConfig
from TTS.encoder.configs.speaker_encoder_config import SpeakerEncoderConfig
@@ -10,15 +11,21 @@ def test_train(tmp_path):
output_path = tmp_path / "train_outputs"
def run_test_train():
- command = (
- f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_encoder.py --config_path {config_path} "
- f"--coqpit.output_path {output_path} "
- "--coqpit.datasets.0.formatter ljspeech_test "
- "--coqpit.datasets.0.meta_file_train metadata.csv "
- "--coqpit.datasets.0.meta_file_val metadata.csv "
- "--coqpit.datasets.0.path tests/data/ljspeech "
- )
- run_cli(command)
+ command = [
+ "--config_path",
+ str(config_path),
+ "--coqpit.output_path",
+ str(output_path),
+ "--coqpit.datasets.0.formatter",
+ "ljspeech_test",
+ "--coqpit.datasets.0.meta_file_train",
+ "metadata.csv",
+ "--coqpit.datasets.0.meta_file_val",
+ "metadata.csv",
+ "--coqpit.datasets.0.path",
+ "tests/data/ljspeech",
+ ]
+ run_main(main, command)
config = SpeakerEncoderConfig(
batch_size=4,
@@ -47,10 +54,7 @@ def run_test_train():
continue_path = max(output_path.iterdir(), key=lambda p: p.stat().st_mtime)
# restore the model and continue training for one more epoch
- command_train = (
- f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_encoder.py --continue_path {continue_path} "
- )
- run_cli(command_train)
+ run_main(main, ["--continue_path", str(continue_path)])
shutil.rmtree(continue_path)
# test resnet speaker encoder
@@ -64,10 +68,7 @@ def run_test_train():
continue_path = max(output_path.iterdir(), key=lambda p: p.stat().st_mtime)
# restore the model and continue training for one more epoch
- command_train = (
- f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_encoder.py --continue_path {continue_path} "
- )
- run_cli(command_train)
+ run_main(main, ["--continue_path", str(continue_path)])
shutil.rmtree(continue_path)
# test model with ge2e loss function
diff --git a/tests/integration/test_train_tts.py b/tests/integration/test_train_tts.py
new file mode 100644
index 0000000000..d1e35ae450
--- /dev/null
+++ b/tests/integration/test_train_tts.py
@@ -0,0 +1,109 @@
+import pytest
+
+from tests.integration import create_config, run_tts_train
+from TTS.tts.configs.align_tts_config import AlignTTSConfig
+from TTS.tts.configs.delightful_tts_config import DelightfulTTSConfig
+from TTS.tts.configs.fast_pitch_config import FastPitchConfig
+from TTS.tts.configs.fastspeech2_config import Fastspeech2Config
+from TTS.tts.configs.glow_tts_config import GlowTTSConfig
+from TTS.tts.configs.neuralhmm_tts_config import NeuralhmmTTSConfig
+from TTS.tts.configs.overflow_config import OverflowConfig
+from TTS.tts.configs.speedy_speech_config import SpeedySpeechConfig
+from TTS.tts.configs.tacotron2_config import Tacotron2Config
+from TTS.tts.configs.tacotron_config import TacotronConfig
+from TTS.tts.configs.vits_config import VitsConfig
+
+SPEAKER_ARGS = (
+ {},
+ {
+ "use_d_vector_file": True,
+ "d_vector_file": "tests/data/ljspeech/speakers.json",
+ "d_vector_dim": 256,
+ },
+ {
+ "use_speaker_embedding": True,
+ "num_speakers": 4,
+ },
+)
+SPEAKER_ARG_IDS = ["single", "dvector", "speaker_emb"]
+
+
+def test_train_align_tts(tmp_path):
+ config = create_config(AlignTTSConfig, use_phonemes=False)
+ run_tts_train(tmp_path, config)
+
+
+@pytest.mark.parametrize("speaker_args", SPEAKER_ARGS, ids=SPEAKER_ARG_IDS)
+def test_train_delightful_tts(tmp_path, speaker_args):
+ config = create_config(
+ DelightfulTTSConfig,
+ batch_size=2,
+ f0_cache_path=tmp_path / "f0_cache", # delightful f0 cache is incompatible with other models
+ binary_align_loss_alpha=0.0,
+ use_attn_priors=False,
+ **speaker_args,
+ )
+ run_tts_train(tmp_path, config)
+
+
+@pytest.mark.parametrize("speaker_args", SPEAKER_ARGS, ids=SPEAKER_ARG_IDS)
+def test_train_fast_pitch(tmp_path, speaker_args):
+ config = create_config(FastPitchConfig, f0_cache_path="tests/data/ljspeech/f0_cache", **speaker_args)
+ config.audio.signal_norm = False
+ config.audio.mel_fmax = 8000
+ config.audio.spec_gain = 1
+ config.audio.log_func = "np.log"
+ run_tts_train(tmp_path, config)
+
+
+@pytest.mark.parametrize("speaker_args", SPEAKER_ARGS, ids=SPEAKER_ARG_IDS)
+def test_train_fast_speech2(tmp_path, speaker_args):
+ config = create_config(
+ Fastspeech2Config,
+ f0_cache_path="tests/data/ljspeech/f0_cache",
+ energy_cache_path=tmp_path / "energy_cache",
+ **speaker_args,
+ )
+ config.audio.signal_norm = False
+ config.audio.mel_fmax = 8000
+ config.audio.spec_gain = 1
+ config.audio.log_func = "np.log"
+ run_tts_train(tmp_path, config)
+
+
+@pytest.mark.parametrize("speaker_args", SPEAKER_ARGS, ids=SPEAKER_ARG_IDS)
+def test_train_glow_tts(tmp_path, speaker_args):
+ config = create_config(GlowTTSConfig, batch_size=2, data_dep_init_steps=1, **speaker_args)
+ run_tts_train(tmp_path, config)
+
+
+def test_train_neuralhmm(tmp_path):
+ config = create_config(NeuralhmmTTSConfig, batch_size=3, eval_batch_size=3, max_sampling_time=50)
+ run_tts_train(tmp_path, config)
+
+
+def test_train_overflow(tmp_path):
+ config = create_config(OverflowConfig, batch_size=3, eval_batch_size=3, max_sampling_time=50)
+ run_tts_train(tmp_path, config)
+
+
+def test_train_speedy_speech(tmp_path):
+ config = create_config(SpeedySpeechConfig)
+ run_tts_train(tmp_path, config)
+
+
+def test_train_tacotron(tmp_path):
+ config = create_config(TacotronConfig, use_phonemes=False, r=5, max_decoder_steps=50)
+ run_tts_train(tmp_path, config)
+
+
+@pytest.mark.parametrize("speaker_args", SPEAKER_ARGS, ids=SPEAKER_ARG_IDS)
+def test_train_tacotron2(tmp_path, speaker_args):
+ config = create_config(Tacotron2Config, use_phonemes=False, r=5, max_decoder_steps=50, **speaker_args)
+ run_tts_train(tmp_path, config)
+
+
+@pytest.mark.parametrize("speaker_args", SPEAKER_ARGS, ids=SPEAKER_ARG_IDS)
+def test_train_vits(tmp_path, speaker_args):
+ config = create_config(VitsConfig, batch_size=2, eval_batch_size=2, **speaker_args)
+ run_tts_train(tmp_path, config)
diff --git a/tests/vocoder_tests/test_training.py b/tests/integration/test_train_vocoder.py
similarity index 100%
rename from tests/vocoder_tests/test_training.py
rename to tests/integration/test_train_vocoder.py
diff --git a/tests/tts_tests/test_vits_multilingual_speaker_emb_train.py b/tests/integration/test_vits_multilingual_speaker_emb_train.py
similarity index 75%
rename from tests/tts_tests/test_vits_multilingual_speaker_emb_train.py
rename to tests/integration/test_vits_multilingual_speaker_emb_train.py
index 189e6cfb4d..9b095935de 100644
--- a/tests/tts_tests/test_vits_multilingual_speaker_emb_train.py
+++ b/tests/integration/test_vits_multilingual_speaker_emb_train.py
@@ -3,7 +3,9 @@
from trainer.io import get_last_checkpoint
-from tests import get_device_id, run_cli
+from tests import run_main
+from TTS.bin.synthesize import main as synthesize
+from TTS.bin.train_tts import main as train_tts
from TTS.config.shared_configs import BaseDatasetConfig
from TTS.tts.configs.vits_config import VitsConfig
@@ -73,12 +75,15 @@ def test_train(tmp_path):
config.save_json(config_path)
# train the model for one epoch
- command_train = (
- f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --config_path {config_path} "
- f"--coqpit.output_path {output_path} "
- "--coqpit.test_delay_epochs 0"
- )
- run_cli(command_train)
+ command_train = [
+ "--config_path",
+ str(config_path),
+ "--coqpit.output_path",
+ str(output_path),
+ "--coqpit.test_delay_epochs",
+ "0",
+ ]
+ run_main(train_tts, command_train)
# Find latest folder
continue_path = max(output_path.iterdir(), key=lambda p: p.stat().st_mtime)
@@ -88,7 +93,7 @@ def test_train(tmp_path):
continue_restore_path, _ = get_last_checkpoint(continue_path)
out_wav_path = tmp_path / "output.wav"
speaker_id = "ljspeech"
- languae_id = "en"
+ language_id = "en"
continue_speakers_path = continue_path / "speakers.json"
continue_languages_path = continue_path / "language_ids.json"
@@ -100,12 +105,26 @@ def test_train(tmp_path):
assert config_loaded["test_delay_epochs"] == 0
# Load the model and run inference
- inference_command = f"CUDA_VISIBLE_DEVICES='{get_device_id()}' tts --text 'This is an example.' --speaker_idx {speaker_id} --speakers_file_path {continue_speakers_path} --language_ids_file_path {continue_languages_path} --language_idx {languae_id} --config_path {continue_config_path} --model_path {continue_restore_path} --out_path {out_wav_path}"
- run_cli(inference_command)
+ inference_command = [
+ "--text",
+ "This is an example for the tests.",
+ "--speaker_idx",
+ speaker_id,
+ "--language_idx",
+ language_id,
+ "--speakers_file_path",
+ str(continue_speakers_path),
+ "--language_ids_file_path",
+ str(continue_languages_path),
+ "--config_path",
+ str(continue_config_path),
+ "--model_path",
+ str(continue_restore_path),
+ "--out_path",
+ str(out_wav_path),
+ ]
+ run_main(synthesize, inference_command)
# restore the model and continue training for one more epoch
- command_train = (
- f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --continue_path {continue_path} "
- )
- run_cli(command_train)
+ run_main(train_tts, ["--continue_path", str(continue_path)])
shutil.rmtree(tmp_path)
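The conversion above is the recurring pattern in this PR: the old run_cli calls spawned a shell with CUDA_VISIBLE_DEVICES set and invoked TTS/bin/train_tts.py or the tts CLI, whereas the new tests call the entry points' main() in-process via run_main with an argv list. As a minimal, hedged sketch of what such a helper can look like (the actual run_main lives in tests/__init__.py and may additionally assert on exit codes):

import sys
from unittest.mock import patch


def run_main(main_func, args=None):
    # Make argparse inside the entry point see the given CLI arguments,
    # then run it in-process instead of spawning a subprocess.
    with patch.object(sys, "argv", ["prog", *(args or [])]):
        main_func()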
diff --git a/tests/tts_tests/test_vits_multilingual_train-d_vectors.py b/tests/integration/test_vits_multilingual_train-d_vectors.py
similarity index 77%
rename from tests/tts_tests/test_vits_multilingual_train-d_vectors.py
rename to tests/integration/test_vits_multilingual_train-d_vectors.py
index 8b8757422c..de0f6ed2b9 100644
--- a/tests/tts_tests/test_vits_multilingual_train-d_vectors.py
+++ b/tests/integration/test_vits_multilingual_train-d_vectors.py
@@ -3,7 +3,9 @@
from trainer.io import get_last_checkpoint
-from tests import get_device_id, run_cli
+from tests import run_main
+from TTS.bin.synthesize import main as synthesize
+from TTS.bin.train_tts import main as train_tts
from TTS.config.shared_configs import BaseDatasetConfig
from TTS.tts.configs.vits_config import VitsConfig
@@ -79,12 +81,15 @@ def test_train(tmp_path):
config.save_json(config_path)
# train the model for one epoch
- command_train = (
- f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --config_path {config_path} "
- f"--coqpit.output_path {output_path} "
- "--coqpit.test_delay_epochs 0"
- )
- run_cli(command_train)
+ command_train = [
+ "--config_path",
+ str(config_path),
+ "--coqpit.output_path",
+ str(output_path),
+ "--coqpit.test_delay_epochs",
+ "0",
+ ]
+ run_main(train_tts, command_train)
# Find latest folder
continue_path = max(output_path.iterdir(), key=lambda p: p.stat().st_mtime)
@@ -94,7 +99,7 @@ def test_train(tmp_path):
continue_restore_path, _ = get_last_checkpoint(continue_path)
out_wav_path = tmp_path / "output.wav"
speaker_id = "ljspeech-1"
- languae_id = "en"
+ language_id = "en"
continue_speakers_path = config.d_vector_file
continue_languages_path = continue_path / "language_ids.json"
@@ -106,12 +111,26 @@ def test_train(tmp_path):
assert config_loaded["test_delay_epochs"] == 0
# Load the model and run inference
- inference_command = f"CUDA_VISIBLE_DEVICES='{get_device_id()}' tts --text 'This is an example.' --speaker_idx {speaker_id} --speakers_file_path {continue_speakers_path} --language_ids_file_path {continue_languages_path} --language_idx {languae_id} --config_path {continue_config_path} --model_path {continue_restore_path} --out_path {out_wav_path}"
- run_cli(inference_command)
+ inference_command = [
+ "--text",
+ "This is an example for the tests.",
+ "--speaker_idx",
+ speaker_id,
+ "--language_idx",
+ language_id,
+ "--speakers_file_path",
+ str(continue_speakers_path),
+ "--language_ids_file_path",
+ str(continue_languages_path),
+ "--config_path",
+ str(continue_config_path),
+ "--model_path",
+ str(continue_restore_path),
+ "--out_path",
+ str(out_wav_path),
+ ]
+ run_main(synthesize, inference_command)
# restore the model and continue training for one more epoch
- command_train = (
- f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --continue_path {continue_path} "
- )
- run_cli(command_train)
+ run_main(train_tts, ["--continue_path", str(continue_path)])
shutil.rmtree(tmp_path)
diff --git a/tests/xtts_tests/test_xtts_gpt_train.py b/tests/integration/test_xtts_gpt_train.py
similarity index 100%
rename from tests/xtts_tests/test_xtts_gpt_train.py
rename to tests/integration/test_xtts_gpt_train.py
diff --git a/tests/text_tests/test_phonemizer.py b/tests/text_tests/test_phonemizer.py
index f9067530e6..370a541b97 100644
--- a/tests/text_tests/test_phonemizer.py
+++ b/tests/text_tests/test_phonemizer.py
@@ -240,12 +240,8 @@ def test_is_available(self):
class TestBN_Phonemizer(unittest.TestCase):
def setUp(self):
self.phonemizer = BN_Phonemizer()
- self._TEST_CASES = (
- "রাসূলুল্লাহ সাল্লাল্লাহু আলাইহি ওয়া সাল্লাম শিক্ষা দিয়েছেন যে, কেউ যদি কোন খারাপ কিছুর সম্মুখীন হয়, তখনও যেন"
- )
- self._EXPECTED = (
- "রাসূলুল্লাহ সাল্লাল্লাহু আলাইহি ওয়া সাল্লাম শিক্ষা দিয়েছেন যে কেউ যদি কোন খারাপ কিছুর সম্মুখীন হয় তখনও যেন।"
- )
+ self._TEST_CASES = "রাসূলুল্লাহ সাল্লাল্লাহু আলাইহি ওয়া সাল্লাম শিক্ষা দিয়েছেন যে, কেউ যদি কোন খারাপ কিছুর সম্মুখীন হয়, তখনও যেন"
+ self._EXPECTED = "রাসূলুল্লাহ সাল্লাল্লাহু আলাইহি ওয়া সাল্লাম শিক্ষা দিয়েছেন যে কেউ যদি কোন খারাপ কিছুর সম্মুখীন হয় তখনও যেন।"
def test_phonemize(self):
self.assertEqual(self.phonemizer.phonemize(self._TEST_CASES, separator=""), self._EXPECTED)
diff --git a/tests/text_tests/test_text_cleaners.py b/tests/text_tests/test_text_cleaners.py
index 25c169eddd..f5d342bb00 100644
--- a/tests/text_tests/test_text_cleaners.py
+++ b/tests/text_tests/test_text_cleaners.py
@@ -45,11 +45,11 @@ def test_normalize_unicode() -> None:
("na\u0303", "nã"),
("o\u0302u", "ôu"),
("n\u0303", "ñ"),
- ("\u4E2D\u56FD", "中国"),
+ ("\u4e2d\u56fd", "中国"),
("niño", "niño"),
("a\u0308", "ä"),
("\u3053\u3093\u306b\u3061\u306f", "こんにちは"),
- ("\u03B1\u03B2", "αβ"),
+ ("\u03b1\u03b2", "αβ"),
]
for arg, expect in test_cases:
assert normalize_unicode(arg) == expect
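The change above only lowercases the escape sequences (a formatter preference); the expected behaviour is identical because NFC normalization composes the combining characters either way. A quick self-contained check of the same expectations with the standard library (assuming normalize_unicode wraps unicodedata, which this diff does not show):

import unicodedata

# Combining tilde (U+0303) composes with the preceding "a" under NFC.
assert unicodedata.normalize("NFC", "na\u0303") == "nã"
# Already-composed CJK text is left untouched.
assert unicodedata.normalize("NFC", "\u4e2d\u56fd") == "中国"
# Escape sequences are equivalent regardless of hex-digit case.
assert "\u03b1\u03b2" == "\u03B1\u03B2" == "αβ"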
diff --git a/tests/tts_tests2/test_delightful_tts_layers.py b/tests/tts_tests/test_delightful_tts_layers.py
similarity index 100%
rename from tests/tts_tests2/test_delightful_tts_layers.py
rename to tests/tts_tests/test_delightful_tts_layers.py
diff --git a/tests/tts_tests2/test_feed_forward_layers.py b/tests/tts_tests/test_feed_forward_layers.py
similarity index 100%
rename from tests/tts_tests2/test_feed_forward_layers.py
rename to tests/tts_tests/test_feed_forward_layers.py
diff --git a/tests/tts_tests2/test_forward_tts.py b/tests/tts_tests/test_forward_tts.py
similarity index 100%
rename from tests/tts_tests2/test_forward_tts.py
rename to tests/tts_tests/test_forward_tts.py
diff --git a/tests/tts_tests2/test_glow_tts.py b/tests/tts_tests/test_glow_tts.py
similarity index 95%
rename from tests/tts_tests2/test_glow_tts.py
rename to tests/tts_tests/test_glow_tts.py
index 3c7ac51556..c92063576f 100644
--- a/tests/tts_tests2/test_glow_tts.py
+++ b/tests/tts_tests/test_glow_tts.py
@@ -42,8 +42,8 @@ def _create_inputs(batch_size=8):
def _check_parameter_changes(model, model_ref):
count = 0
for param, param_ref in zip(model.parameters(), model_ref.parameters()):
- assert (param != param_ref).any(), "param {} with shape {} not updated!! \n{}\n{}".format(
- count, param.shape, param, param_ref
+ assert (param != param_ref).any(), (
+ f"param {count} with shape {param.shape} not updated!! \n{param}\n{param_ref}"
)
count += 1
@@ -107,7 +107,7 @@ def _test_forward(self, batch_size):
config = GlowTTSConfig(num_chars=32)
model = GlowTTS(config).to(device)
model.train()
- print(" > Num parameters for GlowTTS model:%s" % (count_parameters(model)))
+ print(f" > Num parameters for GlowTTS model:{count_parameters(model)}")
# inference encoder and decoder with MAS
y = model.forward(input_dummy, input_lengths, mel_spec, mel_lengths)
self.assertEqual(y["z"].shape, mel_spec.shape)
@@ -134,7 +134,7 @@ def _test_forward_with_d_vector(self, batch_size):
)
model = GlowTTS.init_from_config(config).to(device)
model.train()
- print(" > Num parameters for GlowTTS model:%s" % (count_parameters(model)))
+ print(f" > Num parameters for GlowTTS model:{count_parameters(model)}")
# inference encoder and decoder with MAS
y = model.forward(input_dummy, input_lengths, mel_spec, mel_lengths, {"d_vectors": d_vector})
self.assertEqual(y["z"].shape, mel_spec.shape)
@@ -160,7 +160,7 @@ def _test_forward_with_speaker_id(self, batch_size):
)
model = GlowTTS.init_from_config(config).to(device)
model.train()
- print(" > Num parameters for GlowTTS model:%s" % (count_parameters(model)))
+ print(f" > Num parameters for GlowTTS model:{count_parameters(model)}")
# inference encoder and decoder with MAS
y = model.forward(input_dummy, input_lengths, mel_spec, mel_lengths, {"speaker_ids": speaker_ids})
self.assertEqual(y["z"].shape, mel_spec.shape)
@@ -241,10 +241,10 @@ def _test_inference_with_MAS(self, batch_size):
# inference encoder and decoder with MAS
y = model.inference_with_MAS(input_dummy, input_lengths, mel_spec, mel_lengths)
y2 = model.decoder_inference(mel_spec, mel_lengths)
- assert (
- y2["model_outputs"].shape == y["model_outputs"].shape
- ), "Difference between the shapes of the glowTTS inference with MAS ({}) and the inference using only the decoder ({}) !!".format(
- y["model_outputs"].shape, y2["model_outputs"].shape
+ assert y2["model_outputs"].shape == y["model_outputs"].shape, (
+ "Difference between the shapes of the glowTTS inference with MAS ({}) and the inference using only the decoder ({}) !!".format(
+ y["model_outputs"].shape, y2["model_outputs"].shape
+ )
)
def test_inference_with_MAS(self):
@@ -261,7 +261,7 @@ def test_train_step(self):
# reference model to compare model weights
model_ref = GlowTTS(config).to(device)
model.train()
- print(" > Num parameters for GlowTTS model:%s" % (count_parameters(model)))
+ print(f" > Num parameters for GlowTTS model:{count_parameters(model)}")
# pass the state to ref model
model_ref.load_state_dict(copy.deepcopy(model.state_dict()))
count = 0
diff --git a/tests/tts_tests/test_losses.py b/tests/tts_tests/test_losses.py
index 794478dca3..2290e9a6cc 100644
--- a/tests/tts_tests/test_losses.py
+++ b/tests/tts_tests/test_losses.py
@@ -21,7 +21,7 @@ def test_in_out(self): # pylint: disable=no-self-use
dummy_target = T.zeros(4, 8, 128).float()
dummy_length = (T.ones(4) * 8).long()
output = layer(dummy_input, dummy_target, dummy_length)
- assert output.item() == 1.0, "1.0 vs {}".format(output.item())
+ assert output.item() == 1.0, f"1.0 vs {output.item()}"
# test if padded values of input makes any difference
dummy_input = T.ones(4, 8, 128).float()
@@ -29,14 +29,14 @@ def test_in_out(self): # pylint: disable=no-self-use
dummy_length = (T.arange(5, 9)).long()
mask = ((sequence_mask(dummy_length).float() - 1.0) * 100.0).unsqueeze(2)
output = layer(dummy_input + mask, dummy_target, dummy_length)
- assert output.item() == 1.0, "1.0 vs {}".format(output.item())
+ assert output.item() == 1.0, f"1.0 vs {output.item()}"
dummy_input = T.rand(4, 8, 128).float()
dummy_target = dummy_input.detach()
dummy_length = (T.arange(5, 9)).long()
mask = ((sequence_mask(dummy_length).float() - 1.0) * 100.0).unsqueeze(2)
output = layer(dummy_input + mask, dummy_target, dummy_length)
- assert output.item() == 0, "0 vs {}".format(output.item())
+ assert output.item() == 0, f"0 vs {output.item()}"
# seq_len_norm = True
# test input == target
@@ -52,7 +52,7 @@ def test_in_out(self): # pylint: disable=no-self-use
dummy_target = T.zeros(4, 8, 128).float()
dummy_length = (T.ones(4) * 8).long()
output = layer(dummy_input, dummy_target, dummy_length)
- assert output.item() == 1.0, "1.0 vs {}".format(output.item())
+ assert output.item() == 1.0, f"1.0 vs {output.item()}"
# test if padded values of input makes any difference
dummy_input = T.ones(4, 8, 128).float()
@@ -60,14 +60,14 @@ def test_in_out(self): # pylint: disable=no-self-use
dummy_length = (T.arange(5, 9)).long()
mask = ((sequence_mask(dummy_length).float() - 1.0) * 100.0).unsqueeze(2)
output = layer(dummy_input + mask, dummy_target, dummy_length)
- assert abs(output.item() - 1.0) < 1e-5, "1.0 vs {}".format(output.item())
+ assert abs(output.item() - 1.0) < 1e-5, f"1.0 vs {output.item()}"
dummy_input = T.rand(4, 8, 128).float()
dummy_target = dummy_input.detach()
dummy_length = (T.arange(5, 9)).long()
mask = ((sequence_mask(dummy_length).float() - 1.0) * 100.0).unsqueeze(2)
output = layer(dummy_input + mask, dummy_target, dummy_length)
- assert output.item() == 0, "0 vs {}".format(output.item())
+ assert output.item() == 0, f"0 vs {output.item()}"
class MSELossMaskedTests(unittest.TestCase):
@@ -85,7 +85,7 @@ def test_in_out(self): # pylint: disable=no-self-use
dummy_target = T.zeros(4, 8, 128).float()
dummy_length = (T.ones(4) * 8).long()
output = layer(dummy_input, dummy_target, dummy_length)
- assert output.item() == 1.0, "1.0 vs {}".format(output.item())
+ assert output.item() == 1.0, f"1.0 vs {output.item()}"
# test if padded values of input makes any difference
dummy_input = T.ones(4, 8, 128).float()
@@ -93,14 +93,14 @@ def test_in_out(self): # pylint: disable=no-self-use
dummy_length = (T.arange(5, 9)).long()
mask = ((sequence_mask(dummy_length).float() - 1.0) * 100.0).unsqueeze(2)
output = layer(dummy_input + mask, dummy_target, dummy_length)
- assert output.item() == 1.0, "1.0 vs {}".format(output.item())
+ assert output.item() == 1.0, f"1.0 vs {output.item()}"
dummy_input = T.rand(4, 8, 128).float()
dummy_target = dummy_input.detach()
dummy_length = (T.arange(5, 9)).long()
mask = ((sequence_mask(dummy_length).float() - 1.0) * 100.0).unsqueeze(2)
output = layer(dummy_input + mask, dummy_target, dummy_length)
- assert output.item() == 0, "0 vs {}".format(output.item())
+ assert output.item() == 0, f"0 vs {output.item()}"
# seq_len_norm = True
# test input == target
@@ -116,7 +116,7 @@ def test_in_out(self): # pylint: disable=no-self-use
dummy_target = T.zeros(4, 8, 128).float()
dummy_length = (T.ones(4) * 8).long()
output = layer(dummy_input, dummy_target, dummy_length)
- assert output.item() == 1.0, "1.0 vs {}".format(output.item())
+ assert output.item() == 1.0, f"1.0 vs {output.item()}"
# test if padded values of input makes any difference
dummy_input = T.ones(4, 8, 128).float()
@@ -124,14 +124,14 @@ def test_in_out(self): # pylint: disable=no-self-use
dummy_length = (T.arange(5, 9)).long()
mask = ((sequence_mask(dummy_length).float() - 1.0) * 100.0).unsqueeze(2)
output = layer(dummy_input + mask, dummy_target, dummy_length)
- assert abs(output.item() - 1.0) < 1e-5, "1.0 vs {}".format(output.item())
+ assert abs(output.item() - 1.0) < 1e-5, f"1.0 vs {output.item()}"
dummy_input = T.rand(4, 8, 128).float()
dummy_target = dummy_input.detach()
dummy_length = (T.arange(5, 9)).long()
mask = ((sequence_mask(dummy_length).float() - 1.0) * 100.0).unsqueeze(2)
output = layer(dummy_input + mask, dummy_target, dummy_length)
- assert output.item() == 0, "0 vs {}".format(output.item())
+ assert output.item() == 0, f"0 vs {output.item()}"
class SSIMLossTests(unittest.TestCase):
@@ -153,7 +153,7 @@ def test_in_out(self): # pylint: disable=no-self-use
dummy_length = (T.ones(4) * 58).long()
output = layer(dummy_input, dummy_target, dummy_length)
- assert output.item() >= 1.0, "0 vs {}".format(output.item())
+ assert output.item() >= 1.0, f"0 vs {output.item()}"
# test if padded values of input makes any difference
dummy_input = T.ones(4, 57, 128).float()
@@ -168,7 +168,7 @@ def test_in_out(self): # pylint: disable=no-self-use
dummy_length = (T.arange(54, 58)).long()
mask = ((sequence_mask(dummy_length).float() - 1.0) * 100.0).unsqueeze(2)
output = layer(dummy_input + mask, dummy_target, dummy_length)
- assert output.item() == 0, "0 vs {}".format(output.item())
+ assert output.item() == 0, f"0 vs {output.item()}"
# seq_len_norm = True
# test input == target
@@ -184,7 +184,7 @@ def test_in_out(self): # pylint: disable=no-self-use
dummy_target = T.zeros(4, 57, 128).float()
dummy_length = (T.ones(4) * 8).long()
output = layer(dummy_input, dummy_target, dummy_length)
- assert output.item() == 1.0, "1.0 vs {}".format(output.item())
+ assert output.item() == 1.0, f"1.0 vs {output.item()}"
# test if padded values of input makes any difference
dummy_input = T.ones(4, 57, 128).float()
@@ -192,14 +192,14 @@ def test_in_out(self): # pylint: disable=no-self-use
dummy_length = (T.arange(54, 58)).long()
mask = ((sequence_mask(dummy_length).float() - 1.0) * 100.0).unsqueeze(2)
output = layer(dummy_input + mask, dummy_target, dummy_length)
- assert abs(output.item() - 1.0) < 1e-5, "1.0 vs {}".format(output.item())
+ assert abs(output.item() - 1.0) < 1e-5, f"1.0 vs {output.item()}"
dummy_input = T.rand(4, 57, 128).float()
dummy_target = dummy_input.detach()
dummy_length = (T.arange(54, 58)).long()
mask = ((sequence_mask(dummy_length).float() - 1.0) * 100.0).unsqueeze(2)
output = layer(dummy_input + mask, dummy_target, dummy_length)
- assert output.item() == 0, "0 vs {}".format(output.item())
+ assert output.item() == 0, f"0 vs {output.item()}"
class BCELossTest(unittest.TestCase):
diff --git a/tests/tts_tests/test_neuralhmm_tts_train.py b/tests/tts_tests/test_neuralhmm_tts_train.py
deleted file mode 100644
index f4b8d5cadd..0000000000
--- a/tests/tts_tests/test_neuralhmm_tts_train.py
+++ /dev/null
@@ -1,92 +0,0 @@
-import json
-import shutil
-
-import torch
-from trainer.io import get_last_checkpoint
-
-from tests import get_device_id, run_cli
-from TTS.tts.configs.neuralhmm_tts_config import NeuralhmmTTSConfig
-
-
-def test_train(tmp_path):
- config_path = tmp_path / "test_model_config.json"
- output_path = tmp_path / "train_outputs"
- parameter_path = tmp_path / "lj_parameters.pt"
-
- torch.save({"mean": -5.5138, "std": 2.0636, "init_transition_prob": 0.3212}, parameter_path)
-
- config = NeuralhmmTTSConfig(
- batch_size=3,
- eval_batch_size=3,
- num_loader_workers=0,
- num_eval_loader_workers=0,
- text_cleaner="phoneme_cleaners",
- use_phonemes=True,
- phoneme_language="en-us",
- phoneme_cache_path=output_path / "phoneme_cache",
- run_eval=True,
- test_delay_epochs=-1,
- mel_statistics_parameter_path=parameter_path,
- epochs=1,
- print_step=1,
- test_sentences=[
- "Be a voice, not an echo.",
- ],
- print_eval=True,
- max_sampling_time=50,
- )
- config.audio.do_trim_silence = True
- config.audio.trim_db = 60
- config.save_json(config_path)
-
- # train the model for one epoch when mel parameters exists
- command_train = (
- f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --config_path {config_path} "
- f"--coqpit.output_path {output_path} "
- "--coqpit.datasets.0.formatter ljspeech "
- "--coqpit.datasets.0.meta_file_train metadata.csv "
- "--coqpit.datasets.0.meta_file_val metadata.csv "
- "--coqpit.datasets.0.path tests/data/ljspeech "
- "--coqpit.test_delay_epochs 0 "
- )
- run_cli(command_train)
-
- # train the model for one epoch when mel parameters have to be computed from the dataset
- if parameter_path.is_file():
- parameter_path.unlink()
- command_train = (
- f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --config_path {config_path} "
- f"--coqpit.output_path {output_path} "
- "--coqpit.datasets.0.formatter ljspeech "
- "--coqpit.datasets.0.meta_file_train metadata.csv "
- "--coqpit.datasets.0.meta_file_val metadata.csv "
- "--coqpit.datasets.0.path tests/data/ljspeech "
- "--coqpit.test_delay_epochs 0 "
- )
- run_cli(command_train)
-
- # Find latest folder
- continue_path = max(output_path.iterdir(), key=lambda p: p.stat().st_mtime)
-
- # Inference using TTS API
- continue_config_path = continue_path / "config.json"
- continue_restore_path, _ = get_last_checkpoint(continue_path)
- out_wav_path = tmp_path / "output.wav"
-
- # Check integrity of the config
- with continue_config_path.open() as f:
- config_loaded = json.load(f)
- assert config_loaded["characters"] is not None
- assert config_loaded["output_path"] in str(continue_path)
- assert config_loaded["test_delay_epochs"] == 0
-
- # Load the model and run inference
- inference_command = f"CUDA_VISIBLE_DEVICES='{get_device_id()}' tts --text 'This is an example.' --config_path {continue_config_path} --model_path {continue_restore_path} --out_path {out_wav_path}"
- run_cli(inference_command)
-
- # restore the model and continue training for one more epoch
- command_train = (
- f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --continue_path {continue_path} "
- )
- run_cli(command_train)
- shutil.rmtree(tmp_path)
diff --git a/tests/tts_tests/test_overflow_train.py b/tests/tts_tests/test_overflow_train.py
deleted file mode 100644
index e2dec3c899..0000000000
--- a/tests/tts_tests/test_overflow_train.py
+++ /dev/null
@@ -1,92 +0,0 @@
-import json
-import shutil
-
-import torch
-from trainer.io import get_last_checkpoint
-
-from tests import get_device_id, run_cli
-from TTS.tts.configs.overflow_config import OverflowConfig
-
-
-def test_train(tmp_path):
- config_path = tmp_path / "test_model_config.json"
- output_path = tmp_path / "train_outputs"
- parameter_path = tmp_path / "lj_parameters.pt"
-
- torch.save({"mean": -5.5138, "std": 2.0636, "init_transition_prob": 0.3212}, parameter_path)
-
- config = OverflowConfig(
- batch_size=3,
- eval_batch_size=3,
- num_loader_workers=0,
- num_eval_loader_workers=0,
- text_cleaner="phoneme_cleaners",
- use_phonemes=True,
- phoneme_language="en-us",
- phoneme_cache_path=output_path / "phoneme_cache",
- run_eval=True,
- test_delay_epochs=-1,
- mel_statistics_parameter_path=parameter_path,
- epochs=1,
- print_step=1,
- test_sentences=[
- "Be a voice, not an echo.",
- ],
- print_eval=True,
- max_sampling_time=50,
- )
- config.audio.do_trim_silence = True
- config.audio.trim_db = 60
- config.save_json(config_path)
-
- # train the model for one epoch when mel parameters exists
- command_train = (
- f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --config_path {config_path} "
- f"--coqpit.output_path {output_path} "
- "--coqpit.datasets.0.formatter ljspeech "
- "--coqpit.datasets.0.meta_file_train metadata.csv "
- "--coqpit.datasets.0.meta_file_val metadata.csv "
- "--coqpit.datasets.0.path tests/data/ljspeech "
- "--coqpit.test_delay_epochs 0 "
- )
- run_cli(command_train)
-
- # train the model for one epoch when mel parameters have to be computed from the dataset
- if parameter_path.is_file():
- parameter_path.unlink()
- command_train = (
- f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --config_path {config_path} "
- f"--coqpit.output_path {output_path} "
- "--coqpit.datasets.0.formatter ljspeech "
- "--coqpit.datasets.0.meta_file_train metadata.csv "
- "--coqpit.datasets.0.meta_file_val metadata.csv "
- "--coqpit.datasets.0.path tests/data/ljspeech "
- "--coqpit.test_delay_epochs 0 "
- )
- run_cli(command_train)
-
- # Find latest folder
- continue_path = max(output_path.iterdir(), key=lambda p: p.stat().st_mtime)
-
- # Inference using TTS API
- continue_config_path = continue_path / "config.json"
- continue_restore_path, _ = get_last_checkpoint(continue_path)
- out_wav_path = tmp_path / "output.wav"
-
- # Check integrity of the config
- with continue_config_path.open() as f:
- config_loaded = json.load(f)
- assert config_loaded["characters"] is not None
- assert config_loaded["output_path"] in str(continue_path)
- assert config_loaded["test_delay_epochs"] == 0
-
- # Load the model and run inference
- inference_command = f"CUDA_VISIBLE_DEVICES='{get_device_id()}' tts --text 'This is an example.' --config_path {continue_config_path} --model_path {continue_restore_path} --out_path {out_wav_path}"
- run_cli(inference_command)
-
- # restore the model and continue training for one more epoch
- command_train = (
- f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --continue_path {continue_path} "
- )
- run_cli(command_train)
- shutil.rmtree(tmp_path)
diff --git a/tests/tts_tests/test_speedy_speech_train.py b/tests/tts_tests/test_speedy_speech_train.py
deleted file mode 100644
index 30efe38d9f..0000000000
--- a/tests/tts_tests/test_speedy_speech_train.py
+++ /dev/null
@@ -1,73 +0,0 @@
-import json
-import shutil
-
-from trainer.io import get_last_checkpoint
-
-from tests import get_device_id, run_cli
-from TTS.tts.configs.speedy_speech_config import SpeedySpeechConfig
-
-
-def test_train(tmp_path):
- config_path = tmp_path / "test_speedy_speech_config.json"
- output_path = tmp_path / "train_outputs"
-
- config = SpeedySpeechConfig(
- batch_size=8,
- eval_batch_size=8,
- num_loader_workers=0,
- num_eval_loader_workers=0,
- text_cleaner="english_cleaners",
- use_phonemes=True,
- phoneme_language="en-us",
- phoneme_cache_path=output_path / "phoneme_cache",
- run_eval=True,
- test_delay_epochs=-1,
- epochs=1,
- print_step=1,
- print_eval=True,
- test_sentences=[
- "Be a voice, not an echo.",
- ],
- )
- config.audio.do_trim_silence = True
- config.audio.trim_db = 60
- config.save_json(config_path)
-
- # train the model for one epoch
- command_train = (
- f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --config_path {config_path} "
- f"--coqpit.output_path {output_path} "
- "--coqpit.datasets.0.formatter ljspeech "
- "--coqpit.datasets.0.meta_file_train metadata.csv "
- "--coqpit.datasets.0.meta_file_val metadata.csv "
- "--coqpit.datasets.0.path tests/data/ljspeech "
- "--coqpit.datasets.0.meta_file_attn_mask tests/data/ljspeech/metadata_attn_mask.txt "
- "--coqpit.test_delay_epochs 0"
- )
- run_cli(command_train)
-
- # Find latest folder
- continue_path = max(output_path.iterdir(), key=lambda p: p.stat().st_mtime)
-
- # Inference using TTS API
- continue_config_path = continue_path / "config.json"
- continue_restore_path, _ = get_last_checkpoint(continue_path)
- out_wav_path = tmp_path / "output.wav"
-
- # Check integrity of the config
- with continue_config_path.open() as f:
- config_loaded = json.load(f)
- assert config_loaded["characters"] is not None
- assert config_loaded["output_path"] in str(continue_path)
- assert config_loaded["test_delay_epochs"] == 0
-
- # Load the model and run inference
- inference_command = f"CUDA_VISIBLE_DEVICES='{get_device_id()}' tts --text 'This is an example for it.' --config_path {continue_config_path} --model_path {continue_restore_path} --out_path {out_wav_path}"
- run_cli(inference_command)
-
- # restore the model and continue training for one more epoch
- command_train = (
- f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --continue_path {continue_path} "
- )
- run_cli(command_train)
- shutil.rmtree(tmp_path)
diff --git a/tests/tts_tests/test_tacotron2_d-vectors_train.py b/tests/tts_tests/test_tacotron2_d-vectors_train.py
deleted file mode 100644
index 191e0a19ee..0000000000
--- a/tests/tts_tests/test_tacotron2_d-vectors_train.py
+++ /dev/null
@@ -1,81 +0,0 @@
-import json
-import shutil
-
-from trainer.io import get_last_checkpoint
-
-from tests import get_device_id, run_cli
-from TTS.tts.configs.tacotron2_config import Tacotron2Config
-
-
-def test_train(tmp_path):
- config_path = tmp_path / "test_model_config.json"
- output_path = tmp_path / "train_outputs"
-
- config = Tacotron2Config(
- r=5,
- batch_size=8,
- eval_batch_size=8,
- num_loader_workers=0,
- num_eval_loader_workers=0,
- text_cleaner="english_cleaners",
- use_phonemes=False,
- phoneme_language="en-us",
- phoneme_cache_path=output_path / "phoneme_cache",
- run_eval=True,
- test_delay_epochs=-1,
- epochs=1,
- print_step=1,
- print_eval=True,
- use_speaker_embedding=False,
- use_d_vector_file=True,
- test_sentences=[
- "Be a voice, not an echo.",
- ],
- d_vector_file="tests/data/ljspeech/speakers.json",
- d_vector_dim=256,
- max_decoder_steps=50,
- )
-
- config.audio.do_trim_silence = True
- config.audio.trim_db = 60
- config.save_json(config_path)
-
- # train the model for one epoch
- command_train = (
- f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --config_path {config_path} "
- f"--coqpit.output_path {output_path} "
- "--coqpit.datasets.0.formatter ljspeech_test "
- "--coqpit.datasets.0.meta_file_train metadata.csv "
- "--coqpit.datasets.0.meta_file_val metadata.csv "
- "--coqpit.datasets.0.path tests/data/ljspeech "
- "--coqpit.test_delay_epochs 0 "
- )
- run_cli(command_train)
-
- # Find latest folder
- continue_path = max(output_path.iterdir(), key=lambda p: p.stat().st_mtime)
-
- # Inference using TTS API
- continue_config_path = continue_path / "config.json"
- continue_restore_path, _ = get_last_checkpoint(continue_path)
- out_wav_path = tmp_path / "output.wav"
- speaker_id = "ljspeech-1"
- continue_speakers_path = config.d_vector_file
-
- # Check integrity of the config
- with open(continue_config_path, "r", encoding="utf-8") as f:
- config_loaded = json.load(f)
- assert config_loaded["characters"] is not None
- assert config_loaded["output_path"] in str(continue_path)
- assert config_loaded["test_delay_epochs"] == 0
-
- # Load the model and run inference
- inference_command = f"CUDA_VISIBLE_DEVICES='{get_device_id()}' tts --text 'This is an example.' --speaker_idx {speaker_id} --speakers_file_path {continue_speakers_path} --config_path {continue_config_path} --model_path {continue_restore_path} --out_path {out_wav_path}"
- run_cli(inference_command)
-
- # restore the model and continue training for one more epoch
- command_train = (
- f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --continue_path {continue_path} "
- )
- run_cli(command_train)
- shutil.rmtree(tmp_path)
diff --git a/tests/tts_tests/test_tacotron2_model.py b/tests/tts_tests/test_tacotron2_model.py
index 72b6bcd46b..72069bf943 100644
--- a/tests/tts_tests/test_tacotron2_model.py
+++ b/tests/tts_tests/test_tacotron2_model.py
@@ -72,8 +72,8 @@ def test_train_step(self): # pylint: disable=no-self-use
for param, param_ref in zip(model.parameters(), model_ref.parameters()):
# ignore pre-higway layer since it works conditional
# if count not in [145, 59]:
- assert (param != param_ref).any(), "param {} with shape {} not updated!! \n{}\n{}".format(
- count, param.shape, param, param_ref
+ assert (param != param_ref).any(), (
+ f"param {count} with shape {param.shape} not updated!! \n{param}\n{param_ref}"
)
count += 1
@@ -131,8 +131,8 @@ def test_train_step():
for param, param_ref in zip(model.parameters(), model_ref.parameters()):
# ignore pre-higway layer since it works conditional
# if count not in [145, 59]:
- assert (param != param_ref).any(), "param {} with shape {} not updated!! \n{}\n{}".format(
- count, param.shape, param, param_ref
+ assert (param != param_ref).any(), (
+ f"param {count} with shape {param.shape} not updated!! \n{param}\n{param_ref}"
)
count += 1
@@ -198,8 +198,8 @@ def test_train_step(self):
if name == "gst_layer.encoder.recurrence.weight_hh_l0":
# print(param.grad)
continue
- assert (param != param_ref).any(), "param {} {} with shape {} not updated!! \n{}\n{}".format(
- name, count, param.shape, param, param_ref
+ assert (param != param_ref).any(), (
+ f"param {name} {count} with shape {param.shape} not updated!! \n{param}\n{param_ref}"
)
count += 1
@@ -254,8 +254,8 @@ def test_train_step(self):
if name == "gst_layer.encoder.recurrence.weight_hh_l0":
# print(param.grad)
continue
- assert (param != param_ref).any(), "param {} {} with shape {} not updated!! \n{}\n{}".format(
- name, count, param.shape, param, param_ref
+ assert (param != param_ref).any(), (
+ f"param {name} {count} with shape {param.shape} not updated!! \n{param}\n{param_ref}"
)
count += 1
@@ -321,8 +321,8 @@ def test_train_step():
count = 0
for param, param_ref in zip(model.parameters(), model_ref.parameters()):
# ignore pre-higway layer since it works conditional
- assert (param != param_ref).any(), "param {} with shape {} not updated!! \n{}\n{}".format(
- count, param.shape, param, param_ref
+ assert (param != param_ref).any(), (
+ f"param {count} with shape {param.shape} not updated!! \n{param}\n{param_ref}"
)
count += 1
@@ -384,7 +384,7 @@ def test_train_step():
name, param = name_param
if name == "gst_layer.encoder.recurrence.weight_hh_l0":
continue
- assert (param != param_ref).any(), "param {} with shape {} not updated!! \n{}\n{}".format(
- count, param.shape, param, param_ref
+ assert (param != param_ref).any(), (
+ f"param {count} with shape {param.shape} not updated!! \n{param}\n{param_ref}"
)
count += 1
diff --git a/tests/tts_tests/test_tacotron2_speaker_emb_train.py b/tests/tts_tests/test_tacotron2_speaker_emb_train.py
deleted file mode 100644
index 2696edb1b6..0000000000
--- a/tests/tts_tests/test_tacotron2_speaker_emb_train.py
+++ /dev/null
@@ -1,79 +0,0 @@
-import json
-import shutil
-
-from trainer.io import get_last_checkpoint
-
-from tests import get_device_id, run_cli
-from TTS.tts.configs.tacotron2_config import Tacotron2Config
-
-
-def test_train(tmp_path):
- config_path = tmp_path / "test_model_config.json"
- output_path = tmp_path / "train_outputs"
-
- config = Tacotron2Config(
- r=5,
- batch_size=8,
- eval_batch_size=8,
- num_loader_workers=0,
- num_eval_loader_workers=0,
- text_cleaner="english_cleaners",
- use_phonemes=False,
- phoneme_language="en-us",
- phoneme_cache_path=output_path / "phoneme_cache",
- run_eval=True,
- test_delay_epochs=-1,
- epochs=1,
- print_step=1,
- print_eval=True,
- test_sentences=[
- "Be a voice, not an echo.",
- ],
- use_speaker_embedding=True,
- num_speakers=4,
- max_decoder_steps=50,
- )
-
- config.audio.do_trim_silence = True
- config.audio.trim_db = 60
- config.save_json(config_path)
-
- # train the model for one epoch
- command_train = (
- f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --config_path {config_path} "
- f"--coqpit.output_path {output_path} "
- "--coqpit.datasets.0.formatter ljspeech_test "
- "--coqpit.datasets.0.meta_file_train metadata.csv "
- "--coqpit.datasets.0.meta_file_val metadata.csv "
- "--coqpit.datasets.0.path tests/data/ljspeech "
- "--coqpit.test_delay_epochs 0 "
- )
- run_cli(command_train)
-
- # Find latest folder
- continue_path = max(output_path.iterdir(), key=lambda p: p.stat().st_mtime)
-
- # Inference using TTS API
- continue_config_path = continue_path / "config.json"
- continue_restore_path, _ = get_last_checkpoint(continue_path)
- out_wav_path = tmp_path / "output.wav"
- speaker_id = "ljspeech-1"
- continue_speakers_path = continue_path / "speakers.json"
-
- # Check integrity of the config
- with continue_config_path.open() as f:
- config_loaded = json.load(f)
- assert config_loaded["characters"] is not None
- assert config_loaded["output_path"] in str(continue_path)
- assert config_loaded["test_delay_epochs"] == 0
-
- # Load the model and run inference
- inference_command = f"CUDA_VISIBLE_DEVICES='{get_device_id()}' tts --text 'This is an example.' --speaker_idx {speaker_id} --speakers_file_path {continue_speakers_path} --config_path {continue_config_path} --model_path {continue_restore_path} --out_path {out_wav_path}"
- run_cli(inference_command)
-
- # restore the model and continue training for one more epoch
- command_train = (
- f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --continue_path {continue_path} "
- )
- run_cli(command_train)
- shutil.rmtree(tmp_path)
diff --git a/tests/tts_tests/test_tacotron2_train.py b/tests/tts_tests/test_tacotron2_train.py
deleted file mode 100644
index f8667b6d02..0000000000
--- a/tests/tts_tests/test_tacotron2_train.py
+++ /dev/null
@@ -1,72 +0,0 @@
-import json
-import shutil
-
-from trainer.io import get_last_checkpoint
-
-from tests import get_device_id, run_cli
-from TTS.tts.configs.tacotron2_config import Tacotron2Config
-
-
-def test_train(tmp_path):
- config_path = tmp_path / "test_model_config.json"
- output_path = tmp_path / "train_outputs"
-
- config = Tacotron2Config(
- r=5,
- batch_size=8,
- eval_batch_size=8,
- num_loader_workers=0,
- num_eval_loader_workers=0,
- text_cleaner="english_cleaners",
- use_phonemes=False,
- run_eval=True,
- test_delay_epochs=-1,
- epochs=1,
- print_step=1,
- test_sentences=[
- "Be a voice, not an echo.",
- ],
- print_eval=True,
- max_decoder_steps=50,
- )
- config.audio.do_trim_silence = True
- config.audio.trim_db = 60
- config.save_json(config_path)
-
- # train the model for one epoch
- command_train = (
- f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --config_path {config_path} "
- f"--coqpit.output_path {output_path} "
- "--coqpit.datasets.0.formatter ljspeech "
- "--coqpit.datasets.0.meta_file_train metadata.csv "
- "--coqpit.datasets.0.meta_file_val metadata.csv "
- "--coqpit.datasets.0.path tests/data/ljspeech "
- "--coqpit.test_delay_epochs 0 "
- )
- run_cli(command_train)
-
- # Find latest folder
- continue_path = max(output_path.iterdir(), key=lambda p: p.stat().st_mtime)
-
- # Inference using TTS API
- continue_config_path = continue_path / "config.json"
- continue_restore_path, _ = get_last_checkpoint(continue_path)
- out_wav_path = tmp_path / "output.wav"
-
- # Check integrity of the config
- with continue_config_path.open() as f:
- config_loaded = json.load(f)
- assert config_loaded["characters"] is not None
- assert config_loaded["output_path"] in str(continue_path)
- assert config_loaded["test_delay_epochs"] == 0
-
- # Load the model and run inference
- inference_command = f"CUDA_VISIBLE_DEVICES='{get_device_id()}' tts --text 'This is an example.' --config_path {continue_config_path} --model_path {continue_restore_path} --out_path {out_wav_path}"
- run_cli(inference_command)
-
- # restore the model and continue training for one more epoch
- command_train = (
- f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --continue_path {continue_path} "
- )
- run_cli(command_train)
- shutil.rmtree(tmp_path)
diff --git a/tests/tts_tests/test_tacotron_layers.py b/tests/tts_tests/test_tacotron_layers.py
index 43e72417c2..9521cfea26 100644
--- a/tests/tts_tests/test_tacotron_layers.py
+++ b/tests/tts_tests/test_tacotron_layers.py
@@ -67,8 +67,8 @@ def test_in_out():
output, alignment, stop_tokens = layer(dummy_input, dummy_memory, mask=None)
assert output.shape[0] == 4
- assert output.shape[1] == 80, "size not {}".format(output.shape[1])
- assert output.shape[2] == 2, "size not {}".format(output.shape[2])
+ assert output.shape[1] == 80, f"size not {output.shape[1]}"
+ assert output.shape[2] == 2, f"size not {output.shape[2]}"
assert stop_tokens.shape[0] == 4
diff --git a/tests/tts_tests/test_tacotron_model.py b/tests/tts_tests/test_tacotron_model.py
index 7ec3f0df1b..5f9af86e7e 100644
--- a/tests/tts_tests/test_tacotron_model.py
+++ b/tests/tts_tests/test_tacotron_model.py
@@ -51,7 +51,7 @@ def test_train_step():
criterion_st = nn.BCEWithLogitsLoss().to(device)
model = Tacotron(config).to(device) # FIXME: missing num_speakers parameter to Tacotron ctor
model.train()
- print(" > Num parameters for Tacotron model:%s" % (count_parameters(model)))
+ print(f" > Num parameters for Tacotron model:{count_parameters(model)}")
model_ref = copy.deepcopy(model)
count = 0
for param, param_ref in zip(model.parameters(), model_ref.parameters()):
@@ -71,8 +71,8 @@ def test_train_step():
for param, param_ref in zip(model.parameters(), model_ref.parameters()):
# ignore pre-higway layer since it works conditional
# if count not in [145, 59]:
- assert (param != param_ref).any(), "param {} with shape {} not updated!! \n{}\n{}".format(
- count, param.shape, param, param_ref
+ assert (param != param_ref).any(), (
+ f"param {count} with shape {param.shape} not updated!! \n{param}\n{param_ref}"
)
count += 1
@@ -105,7 +105,7 @@ def test_train_step():
config.d_vector_dim = 55
model = Tacotron(config).to(device) # FIXME: missing num_speakers parameter to Tacotron ctor
model.train()
- print(" > Num parameters for Tacotron model:%s" % (count_parameters(model)))
+ print(f" > Num parameters for Tacotron model:{count_parameters(model)}")
model_ref = copy.deepcopy(model)
count = 0
for param, param_ref in zip(model.parameters(), model_ref.parameters()):
@@ -127,8 +127,8 @@ def test_train_step():
for param, param_ref in zip(model.parameters(), model_ref.parameters()):
# ignore pre-higway layer since it works conditional
# if count not in [145, 59]:
- assert (param != param_ref).any(), "param {} with shape {} not updated!! \n{}\n{}".format(
- count, param.shape, param, param_ref
+ assert (param != param_ref).any(), (
+ f"param {count} with shape {param.shape} not updated!! \n{param}\n{param_ref}"
)
count += 1
@@ -165,7 +165,7 @@ def test_train_step():
model = Tacotron(config).to(device) # FIXME: missing num_speakers parameter to Tacotron ctor
model.train()
# print(model)
- print(" > Num parameters for Tacotron GST model:%s" % (count_parameters(model)))
+ print(f" > Num parameters for Tacotron GST model:{count_parameters(model)}")
model_ref = copy.deepcopy(model)
count = 0
for param, param_ref in zip(model.parameters(), model_ref.parameters()):
@@ -186,8 +186,8 @@ def test_train_step():
count = 0
for param, param_ref in zip(model.parameters(), model_ref.parameters()):
# ignore pre-higway layer since it works conditional
- assert (param != param_ref).any(), "param {} with shape {} not updated!! \n{}\n{}".format(
- count, param.shape, param, param_ref
+ assert (param != param_ref).any(), (
+ f"param {count} with shape {param.shape} not updated!! \n{param}\n{param_ref}"
)
count += 1
@@ -217,7 +217,7 @@ def test_train_step():
model = Tacotron(config).to(device) # FIXME: missing num_speakers parameter to Tacotron ctor
model.train()
# print(model)
- print(" > Num parameters for Tacotron GST model:%s" % (count_parameters(model)))
+ print(f" > Num parameters for Tacotron GST model:{count_parameters(model)}")
model_ref = copy.deepcopy(model)
count = 0
for param, param_ref in zip(model.parameters(), model_ref.parameters()):
@@ -238,8 +238,8 @@ def test_train_step():
count = 0
for param, param_ref in zip(model.parameters(), model_ref.parameters()):
# ignore pre-higway layer since it works conditional
- assert (param != param_ref).any(), "param {} with shape {} not updated!! \n{}\n{}".format(
- count, param.shape, param, param_ref
+ assert (param != param_ref).any(), (
+ f"param {count} with shape {param.shape} not updated!! \n{param}\n{param_ref}"
)
count += 1
@@ -288,7 +288,7 @@ def test_train_step():
criterion = model.get_criterion()
optimizer = model.get_optimizer()
model.train()
- print(" > Num parameters for Tacotron with Capacitron VAE model:%s" % (count_parameters(model)))
+ print(f" > Num parameters for Tacotron with Capacitron VAE model:{count_parameters(model)}")
model_ref = copy.deepcopy(model)
count = 0
for param, param_ref in zip(model.parameters(), model_ref.parameters()):
@@ -305,8 +305,8 @@ def test_train_step():
count = 0
for param, param_ref in zip(model.parameters(), model_ref.parameters()):
# ignore pre-higway layer since it works conditional
- assert (param != param_ref).any(), "param {} with shape {} not updated!! \n{}\n{}".format(
- count, param.shape, param, param_ref
+ assert (param != param_ref).any(), (
+ f"param {count} with shape {param.shape} not updated!! \n{param}\n{param_ref}"
)
count += 1
@@ -341,7 +341,7 @@ def test_train_step():
config.d_vector_dim = 55
model = Tacotron(config).to(device) # FIXME: missing num_speakers parameter to Tacotron ctor
model.train()
- print(" > Num parameters for Tacotron model:%s" % (count_parameters(model)))
+ print(f" > Num parameters for Tacotron model:{count_parameters(model)}")
model_ref = copy.deepcopy(model)
count = 0
for param, param_ref in zip(model.parameters(), model_ref.parameters()):
@@ -366,7 +366,7 @@ def test_train_step():
name, param = name_param
if name == "gst_layer.encoder.recurrence.weight_hh_l0":
continue
- assert (param != param_ref).any(), "param {} with shape {} not updated!! \n{}\n{}".format(
- count, param.shape, param, param_ref
+ assert (param != param_ref).any(), (
+ f"param {count} with shape {param.shape} not updated!! \n{param}\n{param_ref}"
)
count += 1
diff --git a/tests/tts_tests/test_tacotron_train.py b/tests/tts_tests/test_tacotron_train.py
deleted file mode 100644
index cc91b18c34..0000000000
--- a/tests/tts_tests/test_tacotron_train.py
+++ /dev/null
@@ -1,63 +0,0 @@
-import shutil
-
-from trainer.io import get_last_checkpoint
-
-from tests import get_device_id, run_cli
-from TTS.tts.configs.tacotron_config import TacotronConfig
-
-
-def test_train(tmp_path):
- config_path = tmp_path / "test_model_config.json"
- output_path = tmp_path / "train_outputs"
-
- config = TacotronConfig(
- batch_size=8,
- eval_batch_size=8,
- num_loader_workers=0,
- num_eval_loader_workers=0,
- text_cleaner="english_cleaners",
- use_phonemes=False,
- run_eval=True,
- test_delay_epochs=-1,
- epochs=1,
- print_step=1,
- test_sentences=[
- "Be a voice, not an echo.",
- ],
- print_eval=True,
- r=5,
- max_decoder_steps=50,
- )
- config.audio.do_trim_silence = True
- config.audio.trim_db = 60
- config.save_json(config_path)
-
- # train the model for one epoch
- command_train = (
- f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --config_path {config_path} "
- f"--coqpit.output_path {output_path} "
- "--coqpit.datasets.0.formatter ljspeech "
- "--coqpit.datasets.0.meta_file_train metadata.csv "
- "--coqpit.datasets.0.meta_file_val metadata.csv "
- "--coqpit.datasets.0.path tests/data/ljspeech "
- "--coqpit.test_delay_epochs 0"
- )
- run_cli(command_train)
-
- # Find latest folder
- continue_path = max(output_path.iterdir(), key=lambda p: p.stat().st_mtime)
-
- # Inference using TTS API
- continue_config_path = continue_path / "config.json"
- continue_restore_path, _ = get_last_checkpoint(continue_path)
- out_wav_path = tmp_path / "output.wav"
-
- inference_command = f"CUDA_VISIBLE_DEVICES='{get_device_id()}' tts --text 'This is an example.' --config_path {continue_config_path} --model_path {continue_restore_path} --out_path {out_wav_path}"
- run_cli(inference_command)
-
- # restore the model and continue training for one more epoch
- command_train = (
- f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --continue_path {continue_path} "
- )
- run_cli(command_train)
- shutil.rmtree(tmp_path)
diff --git a/tests/tts_tests/test_vits.py b/tests/tts_tests/test_vits.py
index c8a52e1c1b..790439ecb2 100644
--- a/tests/tts_tests/test_vits.py
+++ b/tests/tts_tests/test_vits.py
@@ -373,8 +373,8 @@ def _check_parameter_changes(model, model_ref):
name = item1[0]
param = item1[1]
param_ref = item2[1]
- assert (param != param_ref).any(), "param {} with shape {} not updated!! \n{}\n{}".format(
- name, param.shape, param, param_ref
+ assert (param != param_ref).any(), (
+ f"param {name} with shape {param.shape} not updated!! \n{param}\n{param_ref}"
)
count = count + 1
diff --git a/tests/tts_tests/test_vits_d-vectors_train.py b/tests/tts_tests/test_vits_d-vectors_train.py
deleted file mode 100644
index b95e1deed3..0000000000
--- a/tests/tts_tests/test_vits_d-vectors_train.py
+++ /dev/null
@@ -1,61 +0,0 @@
-import shutil
-
-from tests import get_device_id, run_cli
-from TTS.tts.configs.vits_config import VitsConfig
-
-
-def test_train(tmp_path):
- config_path = tmp_path / "test_model_config.json"
- output_path = tmp_path / "train_outputs"
-
- config = VitsConfig(
- batch_size=2,
- eval_batch_size=2,
- num_loader_workers=0,
- num_eval_loader_workers=0,
- text_cleaner="english_cleaners",
- use_phonemes=True,
- phoneme_language="en-us",
- phoneme_cache_path=output_path / "phoneme_cache",
- run_eval=True,
- test_delay_epochs=-1,
- epochs=1,
- print_step=1,
- print_eval=True,
- test_sentences=[
- ["Be a voice, not an echo.", "ljspeech-0"],
- ],
- )
- # set audio config
- config.audio.do_trim_silence = True
- config.audio.trim_db = 60
-
- # active multispeaker d-vec mode
- config.model_args.use_d_vector_file = True
- config.model_args.d_vector_file = ["tests/data/ljspeech/speakers.json"]
- config.model_args.d_vector_dim = 256
-
- config.save_json(config_path)
-
- # train the model for one epoch
- command_train = (
- f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --config_path {config_path} "
- f"--coqpit.output_path {output_path} "
- "--coqpit.datasets.0.formatter ljspeech "
- "--coqpit.datasets.0.meta_file_train metadata.csv "
- "--coqpit.datasets.0.meta_file_val metadata.csv "
- "--coqpit.datasets.0.path tests/data/ljspeech "
- "--coqpit.datasets.0.meta_file_attn_mask tests/data/ljspeech/metadata_attn_mask.txt "
- "--coqpit.test_delay_epochs 0"
- )
- run_cli(command_train)
-
- # Find latest folder
- continue_path = max(output_path.iterdir(), key=lambda p: p.stat().st_mtime)
-
- # restore the model and continue training for one more epoch
- command_train = (
- f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --continue_path {continue_path} "
- )
- run_cli(command_train)
- shutil.rmtree(tmp_path)
diff --git a/tests/tts_tests/test_vits_speaker_emb_train.py b/tests/tts_tests/test_vits_speaker_emb_train.py
deleted file mode 100644
index 6678cca90c..0000000000
--- a/tests/tts_tests/test_vits_speaker_emb_train.py
+++ /dev/null
@@ -1,83 +0,0 @@
-import json
-import shutil
-
-from trainer.io import get_last_checkpoint
-
-from tests import get_device_id, run_cli
-from TTS.tts.configs.vits_config import VitsConfig
-
-
-def test_train(tmp_path):
- config_path = tmp_path / "test_model_config.json"
- output_path = tmp_path / "train_outputs"
-
- config = VitsConfig(
- batch_size=2,
- eval_batch_size=2,
- num_loader_workers=0,
- num_eval_loader_workers=0,
- text_cleaner="english_cleaners",
- use_phonemes=True,
- phoneme_language="en-us",
- phoneme_cache_path=output_path / "phoneme_cache",
- run_eval=True,
- test_delay_epochs=-1,
- epochs=1,
- print_step=1,
- print_eval=True,
- test_sentences=[
- ["Be a voice, not an echo.", "ljspeech-1"],
- ],
- )
- # set audio config
- config.audio.do_trim_silence = True
- config.audio.trim_db = 60
-
- # active multispeaker d-vec mode
- config.model_args.use_speaker_embedding = True
- config.model_args.use_d_vector_file = False
- config.model_args.d_vector_file = None
- config.model_args.d_vector_dim = 256
-
- config.save_json(config_path)
-
- # train the model for one epoch
- command_train = (
- f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --config_path {config_path} "
- f"--coqpit.output_path {output_path} "
- "--coqpit.datasets.0.formatter ljspeech_test "
- "--coqpit.datasets.0.meta_file_train metadata.csv "
- "--coqpit.datasets.0.meta_file_val metadata.csv "
- "--coqpit.datasets.0.path tests/data/ljspeech "
- "--coqpit.datasets.0.meta_file_attn_mask tests/data/ljspeech/metadata_attn_mask.txt "
- "--coqpit.test_delay_epochs 0"
- )
- run_cli(command_train)
-
- # Find latest folder
- continue_path = max(output_path.iterdir(), key=lambda p: p.stat().st_mtime)
-
- # Inference using TTS API
- continue_config_path = continue_path / "config.json"
- continue_restore_path, _ = get_last_checkpoint(continue_path)
- out_wav_path = tmp_path / "output.wav"
- speaker_id = "ljspeech-1"
- continue_speakers_path = continue_path / "speakers.json"
-
- # Check integrity of the config
- with continue_config_path.open() as f:
- config_loaded = json.load(f)
- assert config_loaded["characters"] is not None
- assert config_loaded["output_path"] in str(continue_path)
- assert config_loaded["test_delay_epochs"] == 0
-
- # Load the model and run inference
- inference_command = f"CUDA_VISIBLE_DEVICES='{get_device_id()}' tts --text 'This is an example.' --speaker_idx {speaker_id} --speakers_file_path {continue_speakers_path} --config_path {continue_config_path} --model_path {continue_restore_path} --out_path {out_wav_path}"
- run_cli(inference_command)
-
- # restore the model and continue training for one more epoch
- command_train = (
- f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --continue_path {continue_path} "
- )
- run_cli(command_train)
- shutil.rmtree(tmp_path)
diff --git a/tests/tts_tests/test_vits_train.py b/tests/tts_tests/test_vits_train.py
deleted file mode 100644
index e0f7a656b0..0000000000
--- a/tests/tts_tests/test_vits_train.py
+++ /dev/null
@@ -1,73 +0,0 @@
-import json
-import shutil
-
-from trainer.io import get_last_checkpoint
-
-from tests import get_device_id, run_cli
-from TTS.tts.configs.vits_config import VitsConfig
-
-
-def test_train(tmp_path):
- config_path = tmp_path / "test_model_config.json"
- output_path = tmp_path / "train_outputs"
-
- config = VitsConfig(
- batch_size=2,
- eval_batch_size=2,
- num_loader_workers=0,
- num_eval_loader_workers=0,
- text_cleaner="english_cleaners",
- use_phonemes=True,
- phoneme_language="en-us",
- phoneme_cache_path=output_path / "phoneme_cache",
- run_eval=True,
- test_delay_epochs=-1,
- epochs=1,
- print_step=1,
- print_eval=True,
- test_sentences=[
- ["Be a voice, not an echo."],
- ],
- )
- config.audio.do_trim_silence = True
- config.audio.trim_db = 60
- config.save_json(config_path)
-
- # train the model for one epoch
- command_train = (
- f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --config_path {config_path} "
- f"--coqpit.output_path {output_path} "
- "--coqpit.datasets.0.formatter ljspeech "
- "--coqpit.datasets.0.meta_file_train metadata.csv "
- "--coqpit.datasets.0.meta_file_val metadata.csv "
- "--coqpit.datasets.0.path tests/data/ljspeech "
- "--coqpit.datasets.0.meta_file_attn_mask tests/data/ljspeech/metadata_attn_mask.txt "
- "--coqpit.test_delay_epochs 0"
- )
- run_cli(command_train)
-
- # Find latest folder
- continue_path = max(output_path.iterdir(), key=lambda p: p.stat().st_mtime)
-
- # Inference using TTS API
- continue_config_path = continue_path / "config.json"
- continue_restore_path, _ = get_last_checkpoint(continue_path)
- out_wav_path = tmp_path / "output.wav"
-
- # Check integrity of the config
- with continue_config_path.open() as f:
- config_loaded = json.load(f)
- assert config_loaded["characters"] is not None
- assert config_loaded["output_path"] in str(continue_path)
- assert config_loaded["test_delay_epochs"] == 0
-
- # Load the model and run inference
- inference_command = f"CUDA_VISIBLE_DEVICES='{get_device_id()}' tts --text 'This is an example.' --config_path {continue_config_path} --model_path {continue_restore_path} --out_path {out_wav_path}"
- run_cli(inference_command)
-
- # restore the model and continue training for one more epoch
- command_train = (
- f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --continue_path {continue_path} "
- )
- run_cli(command_train)
- shutil.rmtree(tmp_path)
diff --git a/tests/tts_tests2/__init__.py b/tests/tts_tests2/__init__.py
deleted file mode 100644
index e69de29bb2..0000000000
diff --git a/tests/tts_tests2/test_align_tts_train.py b/tests/tts_tests2/test_align_tts_train.py
deleted file mode 100644
index 1582f51fd4..0000000000
--- a/tests/tts_tests2/test_align_tts_train.py
+++ /dev/null
@@ -1,71 +0,0 @@
-import json
-import shutil
-
-from trainer.io import get_last_checkpoint
-
-from tests import get_device_id, run_cli
-from TTS.tts.configs.align_tts_config import AlignTTSConfig
-
-
-def test_train(tmp_path):
- config_path = tmp_path / "test_model_config.json"
- output_path = tmp_path / "train_outputs"
-
- config = AlignTTSConfig(
- batch_size=8,
- eval_batch_size=8,
- num_loader_workers=0,
- num_eval_loader_workers=0,
- text_cleaner="english_cleaners",
- use_phonemes=False,
- run_eval=True,
- test_delay_epochs=-1,
- epochs=1,
- print_step=1,
- print_eval=True,
- test_sentences=[
- "Be a voice, not an echo.",
- ],
- )
-
- config.audio.do_trim_silence = True
- config.audio.trim_db = 60
- config.save_json(config_path)
-
- # train the model for one epoch
- command_train = (
- f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --config_path {config_path} "
- f"--coqpit.output_path {output_path} "
- "--coqpit.datasets.0.formatter ljspeech "
- "--coqpit.datasets.0.meta_file_train metadata.csv "
- "--coqpit.datasets.0.meta_file_val metadata.csv "
- "--coqpit.datasets.0.path tests/data/ljspeech "
- "--coqpit.test_delay_epochs 0 "
- )
- run_cli(command_train)
-
- # Find latest folder
- continue_path = max(output_path.iterdir(), key=lambda p: p.stat().st_mtime)
-
- # Inference using TTS API
- continue_config_path = continue_path / "config.json"
- continue_restore_path, _ = get_last_checkpoint(continue_path)
- out_wav_path = tmp_path / "output.wav"
-
- # Check integrity of the config
- with continue_config_path.open() as f:
- config_loaded = json.load(f)
- assert config_loaded["characters"] is not None
- assert config_loaded["output_path"] in str(continue_path)
- assert config_loaded["test_delay_epochs"] == 0
-
- # Load the model and run inference
- inference_command = f"CUDA_VISIBLE_DEVICES='{get_device_id()}' tts --text 'This is an example.' --config_path {continue_config_path} --model_path {continue_restore_path} --out_path {out_wav_path}"
- run_cli(inference_command)
-
- # restore the model and continue training for one more epoch
- command_train = (
- f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --continue_path {continue_path} "
- )
- run_cli(command_train)
- shutil.rmtree(tmp_path)
diff --git a/tests/tts_tests2/test_delightful_tts_d-vectors_train.py b/tests/tts_tests2/test_delightful_tts_d-vectors_train.py
deleted file mode 100644
index 74d7a0a734..0000000000
--- a/tests/tts_tests2/test_delightful_tts_d-vectors_train.py
+++ /dev/null
@@ -1,98 +0,0 @@
-import json
-import shutil
-
-from trainer.io import get_last_checkpoint
-
-from tests import get_device_id, run_cli
-from TTS.tts.configs.delightful_tts_config import DelightfulTtsAudioConfig, DelightfulTTSConfig
-from TTS.tts.models.delightful_tts import DelightfulTtsArgs, VocoderConfig
-
-
-def test_train(tmp_path):
- config_path = tmp_path / "test_model_config.json"
- output_path = tmp_path / "train_outputs"
-
- audio_config = DelightfulTtsAudioConfig()
- model_args = DelightfulTtsArgs(
- use_speaker_embedding=False, d_vector_dim=256, use_d_vector_file=True, speaker_embedding_channels=256
- )
-
- vocoder_config = VocoderConfig()
-
- config = DelightfulTTSConfig(
- model_args=model_args,
- audio=audio_config,
- vocoder=vocoder_config,
- batch_size=2,
- eval_batch_size=8,
- compute_f0=True,
- run_eval=True,
- test_delay_epochs=-1,
- text_cleaner="english_cleaners",
- use_phonemes=True,
- phoneme_language="en-us",
- phoneme_cache_path=tmp_path / "phoneme_cache",
- f0_cache_path=tmp_path / "f0_cache", # delightful f0 cache is incompatible with other models
- epochs=1,
- print_step=1,
- print_eval=True,
- binary_align_loss_alpha=0.0,
- use_attn_priors=False,
- test_sentences=[
- ["Be a voice, not an echo.", "ljspeech-0"],
- ],
- output_path=output_path,
- use_speaker_embedding=False,
- use_d_vector_file=True,
- d_vector_file="tests/data/ljspeech/speakers.json",
- d_vector_dim=256,
- speaker_embedding_channels=256,
- )
-
-    # activate multispeaker d-vector mode
- config.model_args.use_speaker_embedding = False
- config.model_args.use_d_vector_file = True
- config.model_args.d_vector_file = "tests/data/ljspeech/speakers.json"
- config.model_args.d_vector_dim = 256
- config.save_json(config_path)
-
- command_train = (
- f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --config_path {config_path} "
- f"--coqpit.output_path {output_path} "
- "--coqpit.datasets.0.formatter ljspeech "
- "--coqpit.datasets.0.meta_file_train metadata.csv "
- "--coqpit.datasets.0.meta_file_val metadata.csv "
- "--coqpit.datasets.0.path tests/data/ljspeech "
- "--coqpit.datasets.0.meta_file_attn_mask tests/data/ljspeech/metadata_attn_mask.txt "
- "--coqpit.test_delay_epochs 0"
- )
-
- run_cli(command_train)
-
- # Find latest folder
- continue_path = max(output_path.iterdir(), key=lambda p: p.stat().st_mtime)
-
- # Inference using TTS API
- continue_config_path = continue_path / "config.json"
- continue_restore_path, _ = get_last_checkpoint(continue_path)
- out_wav_path = tmp_path / "output.wav"
- speaker_id = "ljspeech-1"
- continue_speakers_path = config.d_vector_file
-
- # Check integrity of the config
- with continue_config_path.open() as f:
- config_loaded = json.load(f)
- assert config_loaded["characters"] is not None
- assert config_loaded["output_path"] in str(continue_path)
- assert config_loaded["test_delay_epochs"] == 0
-
- # Load the model and run inference
- inference_command = f"CUDA_VISIBLE_DEVICES='{get_device_id()}' tts --text 'This is an example.' --speaker_idx {speaker_id} --config_path {continue_config_path} --speakers_file_path {continue_speakers_path} --model_path {continue_restore_path} --out_path {out_wav_path}"
- run_cli(inference_command)
-
- # restore the model and continue training for one more epoch
- command_train = (
- f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --continue_path {continue_path} "
- )
- run_cli(command_train)
- shutil.rmtree(continue_path)
diff --git a/tests/tts_tests2/test_delightful_tts_emb_spk.py b/tests/tts_tests2/test_delightful_tts_emb_spk.py
deleted file mode 100644
index 68f790599e..0000000000
--- a/tests/tts_tests2/test_delightful_tts_emb_spk.py
+++ /dev/null
@@ -1,93 +0,0 @@
-import json
-import shutil
-
-from trainer.io import get_last_checkpoint
-
-from tests import get_device_id, run_cli
-from TTS.tts.configs.delightful_tts_config import DelightfulTtsAudioConfig, DelightfulTTSConfig
-from TTS.tts.models.delightful_tts import DelightfulTtsArgs, VocoderConfig
-
-
-def test_train(tmp_path):
- config_path = tmp_path / "test_model_config.json"
- output_path = tmp_path / "train_outputs"
-
- audio_config = DelightfulTtsAudioConfig()
- model_args = DelightfulTtsArgs(use_speaker_embedding=False)
-
- vocoder_config = VocoderConfig()
-
- config = DelightfulTTSConfig(
- model_args=model_args,
- audio=audio_config,
- vocoder=vocoder_config,
- batch_size=2,
- eval_batch_size=8,
- compute_f0=True,
- run_eval=True,
- test_delay_epochs=-1,
- text_cleaner="english_cleaners",
- use_phonemes=True,
- phoneme_language="en-us",
- phoneme_cache_path=tmp_path / "phoneme_cache",
- f0_cache_path=tmp_path / "f0_cache", # delightful f0 cache is incompatible with other models
- epochs=1,
- print_step=1,
- print_eval=True,
- binary_align_loss_alpha=0.0,
- use_attn_priors=False,
- test_sentences=[
- ["Be a voice, not an echo.", "ljspeech"],
- ],
- output_path=output_path,
- num_speakers=4,
- use_speaker_embedding=True,
- )
-
-    # activate multispeaker speaker-embedding mode
- config.model_args.use_speaker_embedding = True
- config.model_args.use_d_vector_file = False
- config.model_args.d_vector_file = None
- config.model_args.d_vector_dim = 256
- config.save_json(config_path)
-
- command_train = (
- f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --config_path {config_path} "
- f"--coqpit.output_path {output_path} "
- "--coqpit.datasets.0.formatter ljspeech "
- "--coqpit.datasets.0.dataset_name ljspeech "
- "--coqpit.datasets.0.meta_file_train metadata.csv "
- "--coqpit.datasets.0.meta_file_val metadata.csv "
- "--coqpit.datasets.0.path tests/data/ljspeech "
- "--coqpit.datasets.0.meta_file_attn_mask tests/data/ljspeech/metadata_attn_mask.txt "
- "--coqpit.test_delay_epochs 0"
- )
-
- run_cli(command_train)
-
- # Find latest folder
- continue_path = max(output_path.iterdir(), key=lambda p: p.stat().st_mtime)
-
- # Inference using TTS API
- continue_config_path = continue_path / "config.json"
- continue_restore_path, _ = get_last_checkpoint(continue_path)
- out_wav_path = tmp_path / "output.wav"
- speaker_id = "ljspeech"
-
- # Check integrity of the config
- with continue_config_path.open() as f:
- config_loaded = json.load(f)
- assert config_loaded["characters"] is not None
- assert config_loaded["output_path"] in str(continue_path)
- assert config_loaded["test_delay_epochs"] == 0
-
- # Load the model and run inference
- inference_command = f"CUDA_VISIBLE_DEVICES='{get_device_id()}' tts --text 'This is an example.' --speaker_idx {speaker_id} --config_path {continue_config_path} --model_path {continue_restore_path} --out_path {out_wav_path}"
- run_cli(inference_command)
-
- # restore the model and continue training for one more epoch
- command_train = (
- f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --continue_path {continue_path} "
- )
- run_cli(command_train)
- shutil.rmtree(continue_path)
diff --git a/tests/tts_tests2/test_delightful_tts_train.py b/tests/tts_tests2/test_delightful_tts_train.py
deleted file mode 100644
index 4676ee4869..0000000000
--- a/tests/tts_tests2/test_delightful_tts_train.py
+++ /dev/null
@@ -1,97 +0,0 @@
-import json
-import shutil
-
-from trainer.io import get_last_checkpoint
-
-from tests import get_device_id, run_cli
-from TTS.config.shared_configs import BaseAudioConfig
-from TTS.tts.configs.delightful_tts_config import DelightfulTTSConfig
-from TTS.tts.models.delightful_tts import DelightfulTtsArgs, DelightfulTtsAudioConfig, VocoderConfig
-
-
-def test_train(tmp_path):
- config_path = tmp_path / "test_model_config.json"
- output_path = tmp_path / "train_outputs"
-
- audio_config = BaseAudioConfig(
- sample_rate=22050,
- do_trim_silence=True,
- trim_db=60.0,
- signal_norm=False,
- mel_fmin=0.0,
- mel_fmax=8000,
- spec_gain=1.0,
- log_func="np.log",
- ref_level_db=20,
- preemphasis=0.0,
- )
-
- audio_config = DelightfulTtsAudioConfig()
- model_args = DelightfulTtsArgs()
-
- vocoder_config = VocoderConfig()
-
- config = DelightfulTTSConfig(
- audio=audio_config,
- batch_size=2,
- eval_batch_size=8,
- num_loader_workers=0,
- num_eval_loader_workers=0,
- text_cleaner="english_cleaners",
- use_phonemes=True,
- phoneme_language="en-us",
- phoneme_cache_path=tmp_path / "phoneme_cache",
- f0_cache_path=tmp_path / "f0_cache", # delightful f0 cache is incompatible with other models
- run_eval=True,
- test_delay_epochs=-1,
- binary_align_loss_alpha=0.0,
- epochs=1,
- print_step=1,
- use_attn_priors=False,
- print_eval=True,
- test_sentences=[
- ["Be a voice, not an echo."],
- ],
- use_speaker_embedding=False,
- )
- config.save_json(config_path)
-
- # train the model for one epoch
- command_train = (
- f"CUDA_VISIBLE_DEVICES='{'cpu'}' python TTS/bin/train_tts.py --config_path {config_path} "
- f"--coqpit.output_path {output_path} "
- "--coqpit.datasets.0.formatter ljspeech "
- "--coqpit.datasets.0.meta_file_train metadata.csv "
- "--coqpit.datasets.0.meta_file_val metadata.csv "
- "--coqpit.datasets.0.path tests/data/ljspeech "
- "--coqpit.datasets.0.meta_file_attn_mask tests/data/ljspeech/metadata_attn_mask.txt "
- "--coqpit.test_delay_epochs -1"
- )
-
- run_cli(command_train)
-
- # Find latest folder
- continue_path = max(output_path.iterdir(), key=lambda p: p.stat().st_mtime)
-
- # Inference using TTS API
- continue_config_path = continue_path / "config.json"
- continue_restore_path, _ = get_last_checkpoint(continue_path)
- out_wav_path = tmp_path / "output.wav"
-
- # Check integrity of the config
- with continue_config_path.open() as f:
- config_loaded = json.load(f)
- assert config_loaded["characters"] is not None
- assert config_loaded["output_path"] in str(continue_path)
- assert config_loaded["test_delay_epochs"] == -1
-
- # Load the model and run inference
- inference_command = f"CUDA_VISIBLE_DEVICES='{get_device_id()}' tts --text 'This is an example.' --config_path {continue_config_path} --model_path {continue_restore_path} --out_path {out_wav_path}"
- run_cli(inference_command)
-
- # restore the model and continue training for one more epoch
- command_train = (
- f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --continue_path {continue_path} "
- )
- run_cli(command_train)
- shutil.rmtree(continue_path)
diff --git a/tests/tts_tests2/test_fast_pitch_speaker_emb_train.py b/tests/tts_tests2/test_fast_pitch_speaker_emb_train.py
deleted file mode 100644
index 379e2f346b..0000000000
--- a/tests/tts_tests2/test_fast_pitch_speaker_emb_train.py
+++ /dev/null
@@ -1,94 +0,0 @@
-import json
-import shutil
-
-from trainer.io import get_last_checkpoint
-
-from tests import get_device_id, run_cli
-from TTS.config.shared_configs import BaseAudioConfig
-from TTS.tts.configs.fast_pitch_config import FastPitchConfig
-
-
-def test_train(tmp_path):
- config_path = tmp_path / "fast_pitch_speaker_emb_config.json"
- output_path = tmp_path / "train_outputs"
-
- audio_config = BaseAudioConfig(
- sample_rate=22050,
- do_trim_silence=True,
- trim_db=60.0,
- signal_norm=False,
- mel_fmin=0.0,
- mel_fmax=8000,
- spec_gain=1.0,
- log_func="np.log",
- ref_level_db=20,
- preemphasis=0.0,
- )
-
- config = FastPitchConfig(
- audio=audio_config,
- batch_size=8,
- eval_batch_size=8,
- num_loader_workers=0,
- num_eval_loader_workers=0,
- text_cleaner="english_cleaners",
- use_phonemes=True,
- phoneme_language="en-us",
- phoneme_cache_path=tmp_path / "phoneme_cache",
- f0_cache_path="tests/data/ljspeech/f0_cache/",
- run_eval=True,
- test_delay_epochs=-1,
- epochs=1,
- print_step=1,
- print_eval=True,
- use_speaker_embedding=True,
- test_sentences=[
- "Be a voice, not an echo.",
- ],
- )
- config.audio.do_trim_silence = True
- config.use_speaker_embedding = True
- config.model_args.use_speaker_embedding = True
- config.audio.trim_db = 60
- config.save_json(config_path)
-
- # train the model for one epoch
- command_train = (
- f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --config_path {config_path} "
- f"--coqpit.output_path {output_path} "
- "--coqpit.datasets.0.formatter ljspeech_test "
- "--coqpit.datasets.0.meta_file_train metadata.csv "
- "--coqpit.datasets.0.meta_file_val metadata.csv "
- "--coqpit.datasets.0.path tests/data/ljspeech "
- "--coqpit.datasets.0.meta_file_attn_mask tests/data/ljspeech/metadata_attn_mask.txt "
- "--coqpit.test_delay_epochs 0"
- )
- run_cli(command_train)
-
- # Find latest folder
- continue_path = max(output_path.iterdir(), key=lambda p: p.stat().st_mtime)
-
- # Inference using TTS API
- continue_config_path = continue_path / "config.json"
- continue_restore_path, _ = get_last_checkpoint(continue_path)
- out_wav_path = tmp_path / "output.wav"
- speaker_id = "ljspeech-1"
- continue_speakers_path = continue_path / "speakers.json"
-
- # Check integrity of the config
- with continue_config_path.open() as f:
- config_loaded = json.load(f)
- assert config_loaded["characters"] is not None
- assert config_loaded["output_path"] in str(continue_path)
- assert config_loaded["test_delay_epochs"] == 0
-
- # Load the model and run inference
- inference_command = f"CUDA_VISIBLE_DEVICES='{get_device_id()}' tts --text 'This is an example.' --speaker_idx {speaker_id} --speakers_file_path {continue_speakers_path} --config_path {continue_config_path} --model_path {continue_restore_path} --out_path {out_wav_path}"
- run_cli(inference_command)
-
- # restore the model and continue training for one more epoch
- command_train = (
- f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --continue_path {continue_path} "
- )
- run_cli(command_train)
- shutil.rmtree(continue_path)
diff --git a/tests/tts_tests2/test_fast_pitch_train.py b/tests/tts_tests2/test_fast_pitch_train.py
deleted file mode 100644
index e0838a2049..0000000000
--- a/tests/tts_tests2/test_fast_pitch_train.py
+++ /dev/null
@@ -1,93 +0,0 @@
-import json
-import shutil
-
-from trainer.io import get_last_checkpoint
-
-from tests import get_device_id, run_cli
-from TTS.config.shared_configs import BaseAudioConfig
-from TTS.tts.configs.fast_pitch_config import FastPitchConfig
-
-
-def test_train(tmp_path):
- config_path = tmp_path / "test_model_config.json"
- output_path = tmp_path / "train_outputs"
-
- audio_config = BaseAudioConfig(
- sample_rate=22050,
- do_trim_silence=True,
- trim_db=60.0,
- signal_norm=False,
- mel_fmin=0.0,
- mel_fmax=8000,
- spec_gain=1.0,
- log_func="np.log",
- ref_level_db=20,
- preemphasis=0.0,
- )
-
- config = FastPitchConfig(
- audio=audio_config,
- batch_size=8,
- eval_batch_size=8,
- num_loader_workers=0,
- num_eval_loader_workers=0,
- text_cleaner="english_cleaners",
- use_phonemes=True,
- phoneme_language="en-us",
- phoneme_cache_path=tmp_path / "phoneme_cache",
- f0_cache_path="tests/data/ljspeech/f0_cache/",
- run_eval=True,
- test_delay_epochs=-1,
- epochs=1,
- print_step=1,
- print_eval=True,
- test_sentences=[
- "Be a voice, not an echo.",
- ],
- use_speaker_embedding=False,
- )
- config.audio.do_trim_silence = True
- config.use_speaker_embedding = False
- config.model_args.use_speaker_embedding = False
- config.audio.trim_db = 60
- config.save_json(config_path)
-
- # train the model for one epoch
- command_train = (
- f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --config_path {config_path} "
- f"--coqpit.output_path {output_path} "
- "--coqpit.datasets.0.formatter ljspeech "
- "--coqpit.datasets.0.meta_file_train metadata.csv "
- "--coqpit.datasets.0.meta_file_val metadata.csv "
- "--coqpit.datasets.0.path tests/data/ljspeech "
- "--coqpit.datasets.0.meta_file_attn_mask tests/data/ljspeech/metadata_attn_mask.txt "
- "--coqpit.test_delay_epochs 0"
- )
-
- run_cli(command_train)
-
- # Find latest folder
- continue_path = max(output_path.iterdir(), key=lambda p: p.stat().st_mtime)
-
- # Inference using TTS API
- continue_config_path = continue_path / "config.json"
- continue_restore_path, _ = get_last_checkpoint(continue_path)
- out_wav_path = tmp_path / "output.wav"
-
- # Check integrity of the config
- with continue_config_path.open() as f:
- config_loaded = json.load(f)
- assert config_loaded["characters"] is not None
- assert config_loaded["output_path"] in str(continue_path)
- assert config_loaded["test_delay_epochs"] == 0
-
- # Load the model and run inference
- inference_command = f"CUDA_VISIBLE_DEVICES='{get_device_id()}' tts --text 'This is an example.' --config_path {continue_config_path} --model_path {continue_restore_path} --out_path {out_wav_path}"
- run_cli(inference_command)
-
- # restore the model and continue training for one more epoch
- command_train = (
- f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --continue_path {continue_path} "
- )
- run_cli(command_train)
- shutil.rmtree(continue_path)
diff --git a/tests/tts_tests2/test_fastspeech_2_speaker_emb_train.py b/tests/tts_tests2/test_fastspeech_2_speaker_emb_train.py
deleted file mode 100644
index 348729c6f4..0000000000
--- a/tests/tts_tests2/test_fastspeech_2_speaker_emb_train.py
+++ /dev/null
@@ -1,97 +0,0 @@
-import json
-import shutil
-
-from trainer.io import get_last_checkpoint
-
-from tests import get_device_id, run_cli
-from TTS.config.shared_configs import BaseAudioConfig
-from TTS.tts.configs.fastspeech2_config import Fastspeech2Config
-
-
-def test_train(tmp_path):
- config_path = tmp_path / "fast_pitch_speaker_emb_config.json"
- output_path = tmp_path / "train_outputs"
-
- audio_config = BaseAudioConfig(
- sample_rate=22050,
- do_trim_silence=True,
- trim_db=60.0,
- signal_norm=False,
- mel_fmin=0.0,
- mel_fmax=8000,
- spec_gain=1.0,
- log_func="np.log",
- ref_level_db=20,
- preemphasis=0.0,
- )
-
- config = Fastspeech2Config(
- audio=audio_config,
- batch_size=8,
- eval_batch_size=8,
- num_loader_workers=0,
- num_eval_loader_workers=0,
- text_cleaner="english_cleaners",
- use_phonemes=True,
- phoneme_language="en-us",
- phoneme_cache_path=tmp_path / "phoneme_cache",
- f0_cache_path="tests/data/ljspeech/f0_cache/",
- compute_f0=True,
- compute_energy=True,
- energy_cache_path=tmp_path / "energy_cache",
- run_eval=True,
- test_delay_epochs=-1,
- epochs=1,
- print_step=1,
- print_eval=True,
- use_speaker_embedding=True,
- test_sentences=[
- "Be a voice, not an echo.",
- ],
- )
- config.audio.do_trim_silence = True
- config.use_speaker_embedding = True
- config.model_args.use_speaker_embedding = True
- config.audio.trim_db = 60
- config.save_json(config_path)
-
- # train the model for one epoch
- command_train = (
- f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --config_path {config_path} "
- f"--coqpit.output_path {output_path} "
- "--coqpit.datasets.0.formatter ljspeech_test "
- "--coqpit.datasets.0.meta_file_train metadata.csv "
- "--coqpit.datasets.0.meta_file_val metadata.csv "
- "--coqpit.datasets.0.path tests/data/ljspeech "
- "--coqpit.datasets.0.meta_file_attn_mask tests/data/ljspeech/metadata_attn_mask.txt "
- "--coqpit.test_delay_epochs 0"
- )
- run_cli(command_train)
-
- # Find latest folder
- continue_path = max(output_path.iterdir(), key=lambda p: p.stat().st_mtime)
-
- # Inference using TTS API
- continue_config_path = continue_path / "config.json"
- continue_restore_path, _ = get_last_checkpoint(continue_path)
- out_wav_path = tmp_path / "output.wav"
- speaker_id = "ljspeech-1"
- continue_speakers_path = continue_path / "speakers.json"
-
- # Check integrity of the config
- with continue_config_path.open() as f:
- config_loaded = json.load(f)
- assert config_loaded["characters"] is not None
- assert config_loaded["output_path"] in str(continue_path)
- assert config_loaded["test_delay_epochs"] == 0
-
- # Load the model and run inference
- inference_command = f"CUDA_VISIBLE_DEVICES='{get_device_id()}' tts --text 'This is an example.' --speaker_idx {speaker_id} --speakers_file_path {continue_speakers_path} --config_path {continue_config_path} --model_path {continue_restore_path} --out_path {out_wav_path}"
- run_cli(inference_command)
-
- # restore the model and continue training for one more epoch
- command_train = (
- f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --continue_path {continue_path} "
- )
- run_cli(command_train)
- shutil.rmtree(continue_path)
diff --git a/tests/tts_tests2/test_fastspeech_2_train.py b/tests/tts_tests2/test_fastspeech_2_train.py
deleted file mode 100644
index ab513ec827..0000000000
--- a/tests/tts_tests2/test_fastspeech_2_train.py
+++ /dev/null
@@ -1,96 +0,0 @@
-import json
-import shutil
-
-from trainer.io import get_last_checkpoint
-
-from tests import get_device_id, run_cli
-from TTS.config.shared_configs import BaseAudioConfig
-from TTS.tts.configs.fastspeech2_config import Fastspeech2Config
-
-
-def test_train(tmp_path):
- config_path = tmp_path / "test_model_config.json"
- output_path = tmp_path / "train_outputs"
-
- audio_config = BaseAudioConfig(
- sample_rate=22050,
- do_trim_silence=True,
- trim_db=60.0,
- signal_norm=False,
- mel_fmin=0.0,
- mel_fmax=8000,
- spec_gain=1.0,
- log_func="np.log",
- ref_level_db=20,
- preemphasis=0.0,
- )
-
- config = Fastspeech2Config(
- audio=audio_config,
- batch_size=8,
- eval_batch_size=8,
- num_loader_workers=0,
- num_eval_loader_workers=0,
- text_cleaner="english_cleaners",
- use_phonemes=True,
- phoneme_language="en-us",
- phoneme_cache_path=output_path / "phoneme_cache",
- f0_cache_path="tests/data/ljspeech/f0_cache/",
- compute_f0=True,
- compute_energy=True,
- energy_cache_path=output_path / "energy_cache",
- run_eval=True,
- test_delay_epochs=-1,
- epochs=1,
- print_step=1,
- print_eval=True,
- test_sentences=[
- "Be a voice, not an echo.",
- ],
- use_speaker_embedding=False,
- )
- config.audio.do_trim_silence = True
- config.use_speaker_embedding = False
- config.model_args.use_speaker_embedding = False
- config.audio.trim_db = 60
- config.save_json(config_path)
-
- # train the model for one epoch
- command_train = (
- f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --config_path {config_path} "
- f"--coqpit.output_path {output_path} "
- "--coqpit.datasets.0.formatter ljspeech "
- "--coqpit.datasets.0.meta_file_train metadata.csv "
- "--coqpit.datasets.0.meta_file_val metadata.csv "
- "--coqpit.datasets.0.path tests/data/ljspeech "
- "--coqpit.datasets.0.meta_file_attn_mask tests/data/ljspeech/metadata_attn_mask.txt "
- "--coqpit.test_delay_epochs 0"
- )
-
- run_cli(command_train)
-
- # Find latest folder
- continue_path = max(output_path.iterdir(), key=lambda p: p.stat().st_mtime)
-
- # Inference using TTS API
- continue_config_path = continue_path / "config.json"
- continue_restore_path, _ = get_last_checkpoint(continue_path)
- out_wav_path = tmp_path / "output.wav"
-
- # Check integrity of the config
- with continue_config_path.open() as f:
- config_loaded = json.load(f)
- assert config_loaded["characters"] is not None
- assert config_loaded["output_path"] in str(continue_path)
- assert config_loaded["test_delay_epochs"] == 0
-
- # Load the model and run inference
- inference_command = f"CUDA_VISIBLE_DEVICES='{get_device_id()}' tts --text 'This is an example.' --config_path {continue_config_path} --model_path {continue_restore_path} --out_path {out_wav_path}"
- run_cli(inference_command)
-
- # restore the model and continue training for one more epoch
- command_train = (
- f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --continue_path {continue_path} "
- )
- run_cli(command_train)
- shutil.rmtree(continue_path)
diff --git a/tests/tts_tests2/test_glow_tts_d-vectors_train.py b/tests/tts_tests2/test_glow_tts_d-vectors_train.py
deleted file mode 100644
index f03139ac77..0000000000
--- a/tests/tts_tests2/test_glow_tts_d-vectors_train.py
+++ /dev/null
@@ -1,80 +0,0 @@
-import json
-import shutil
-
-from trainer.io import get_last_checkpoint
-
-from tests import get_device_id, run_cli
-from TTS.tts.configs.glow_tts_config import GlowTTSConfig
-
-
-def test_train(tmp_path):
- config_path = tmp_path / "test_model_config.json"
- output_path = tmp_path / "train_outputs"
-
- config = GlowTTSConfig(
- batch_size=2,
- eval_batch_size=8,
- num_loader_workers=0,
- num_eval_loader_workers=0,
- text_cleaner="english_cleaners",
- use_phonemes=True,
- phoneme_language="en-us",
- phoneme_cache_path=output_path / "phoneme_cache",
- run_eval=True,
- test_delay_epochs=-1,
- epochs=1,
- print_step=1,
- print_eval=True,
- test_sentences=[
- "Be a voice, not an echo.",
- ],
- data_dep_init_steps=1.0,
- use_speaker_embedding=False,
- use_d_vector_file=True,
- d_vector_file="tests/data/ljspeech/speakers.json",
- d_vector_dim=256,
- )
- config.audio.do_trim_silence = True
- config.audio.trim_db = 60
- config.save_json(config_path)
-
- # train the model for one epoch
- command_train = (
- f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --config_path {config_path} "
- f"--coqpit.output_path {output_path} "
- "--coqpit.datasets.0.formatter ljspeech_test "
- "--coqpit.datasets.0.meta_file_train metadata.csv "
- "--coqpit.datasets.0.meta_file_val metadata.csv "
- "--coqpit.datasets.0.path tests/data/ljspeech "
- "--coqpit.datasets.0.meta_file_attn_mask tests/data/ljspeech/metadata_attn_mask.txt "
- "--coqpit.test_delay_epochs 0"
- )
- run_cli(command_train)
-
- # Find latest folder
- continue_path = max(output_path.iterdir(), key=lambda p: p.stat().st_mtime)
-
- # Inference using TTS API
- continue_config_path = continue_path / "config.json"
- continue_restore_path, _ = get_last_checkpoint(continue_path)
- out_wav_path = tmp_path / "output.wav"
- speaker_id = "ljspeech-1"
- continue_speakers_path = config.d_vector_file
-
- # Check integrity of the config
- with continue_config_path.open() as f:
- config_loaded = json.load(f)
- assert config_loaded["characters"] is not None
- assert config_loaded["output_path"] in str(continue_path)
- assert config_loaded["test_delay_epochs"] == 0
-
- # Load the model and run inference
- inference_command = f"CUDA_VISIBLE_DEVICES='{get_device_id()}' tts --text 'This is an example.' --speaker_idx {speaker_id} --speakers_file_path {continue_speakers_path} --config_path {continue_config_path} --model_path {continue_restore_path} --out_path {out_wav_path}"
- run_cli(inference_command)
-
- # restore the model and continue training for one more epoch
- command_train = (
- f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --continue_path {continue_path} "
- )
- run_cli(command_train)
- shutil.rmtree(continue_path)
diff --git a/tests/tts_tests2/test_glow_tts_speaker_emb_train.py b/tests/tts_tests2/test_glow_tts_speaker_emb_train.py
deleted file mode 100644
index b9fe93a2fa..0000000000
--- a/tests/tts_tests2/test_glow_tts_speaker_emb_train.py
+++ /dev/null
@@ -1,77 +0,0 @@
-import json
-import shutil
-
-from trainer.io import get_last_checkpoint
-
-from tests import get_device_id, run_cli
-from TTS.tts.configs.glow_tts_config import GlowTTSConfig
-
-
-def test_train(tmp_path):
- config_path = tmp_path / "test_model_config.json"
- output_path = tmp_path / "train_outputs"
-
- config = GlowTTSConfig(
- batch_size=2,
- eval_batch_size=8,
- num_loader_workers=0,
- num_eval_loader_workers=0,
- text_cleaner="english_cleaners",
- use_phonemes=True,
- phoneme_language="en-us",
- phoneme_cache_path=tmp_path / "phoneme_cache",
- run_eval=True,
- test_delay_epochs=-1,
- epochs=1,
- print_step=1,
- print_eval=True,
- test_sentences=[
- "Be a voice, not an echo.",
- ],
- data_dep_init_steps=1.0,
- use_speaker_embedding=True,
- )
- config.audio.do_trim_silence = True
- config.audio.trim_db = 60
- config.save_json(config_path)
-
- # train the model for one epoch
- command_train = (
- f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --config_path {config_path} "
- f"--coqpit.output_path {output_path} "
- "--coqpit.datasets.0.formatter ljspeech_test "
- "--coqpit.datasets.0.meta_file_train metadata.csv "
- "--coqpit.datasets.0.meta_file_val metadata.csv "
- "--coqpit.datasets.0.path tests/data/ljspeech "
- "--coqpit.datasets.0.meta_file_attn_mask tests/data/ljspeech/metadata_attn_mask.txt "
- "--coqpit.test_delay_epochs 0"
- )
- run_cli(command_train)
-
- # Find latest folder
- continue_path = max(output_path.iterdir(), key=lambda p: p.stat().st_mtime)
-
- # Inference using TTS API
- continue_config_path = continue_path / "config.json"
- continue_restore_path, _ = get_last_checkpoint(continue_path)
- out_wav_path = tmp_path / "output.wav"
- speaker_id = "ljspeech-1"
- continue_speakers_path = continue_path / "speakers.json"
-
- # Check integrity of the config
- with continue_config_path.open() as f:
- config_loaded = json.load(f)
- assert config_loaded["characters"] is not None
- assert config_loaded["output_path"] in str(continue_path)
- assert config_loaded["test_delay_epochs"] == 0
-
- # Load the model and run inference
- inference_command = f"CUDA_VISIBLE_DEVICES='{get_device_id()}' tts --text 'This is an example.' --speaker_idx {speaker_id} --speakers_file_path {continue_speakers_path} --config_path {continue_config_path} --model_path {continue_restore_path} --out_path {out_wav_path}"
- run_cli(inference_command)
-
- # restore the model and continue training for one more epoch
- command_train = (
- f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --continue_path {continue_path} "
- )
- run_cli(command_train)
- shutil.rmtree(continue_path)
diff --git a/tests/tts_tests2/test_glow_tts_train.py b/tests/tts_tests2/test_glow_tts_train.py
deleted file mode 100644
index 3f1bf3a794..0000000000
--- a/tests/tts_tests2/test_glow_tts_train.py
+++ /dev/null
@@ -1,74 +0,0 @@
-import json
-import shutil
-
-from trainer.io import get_last_checkpoint
-
-from tests import get_device_id, run_cli
-from TTS.tts.configs.glow_tts_config import GlowTTSConfig
-
-
-def test_train(tmp_path):
- config_path = tmp_path / "test_model_config.json"
- output_path = tmp_path / "train_outputs"
-
- config = GlowTTSConfig(
- batch_size=2,
- eval_batch_size=8,
- num_loader_workers=0,
- num_eval_loader_workers=0,
- text_cleaner="english_cleaners",
- use_phonemes=True,
- phoneme_language="en-us",
- phoneme_cache_path=tmp_path / "phoneme_cache",
- run_eval=True,
- test_delay_epochs=-1,
- epochs=1,
- print_step=1,
- print_eval=True,
- test_sentences=[
- "Be a voice, not an echo.",
- ],
- data_dep_init_steps=1.0,
- )
- config.audio.do_trim_silence = True
- config.audio.trim_db = 60
- config.save_json(config_path)
-
- # train the model for one epoch
- command_train = (
- f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --config_path {config_path} "
- f"--coqpit.output_path {output_path} "
- "--coqpit.datasets.0.formatter ljspeech "
- "--coqpit.datasets.0.meta_file_train metadata.csv "
- "--coqpit.datasets.0.meta_file_val metadata.csv "
- "--coqpit.datasets.0.path tests/data/ljspeech "
- "--coqpit.datasets.0.meta_file_attn_mask tests/data/ljspeech/metadata_attn_mask.txt "
- "--coqpit.test_delay_epochs 0"
- )
- run_cli(command_train)
-
- # Find latest folder
- continue_path = max(output_path.iterdir(), key=lambda p: p.stat().st_mtime)
-
- # Inference using TTS API
- continue_config_path = continue_path / "config.json"
- continue_restore_path, _ = get_last_checkpoint(continue_path)
- out_wav_path = tmp_path / "output.wav"
-
- # Check integrity of the config
- with continue_config_path.open() as f:
- config_loaded = json.load(f)
- assert config_loaded["characters"] is not None
- assert config_loaded["output_path"] in str(continue_path)
- assert config_loaded["test_delay_epochs"] == 0
-
- # Load the model and run inference
- inference_command = f"CUDA_VISIBLE_DEVICES='{get_device_id()}' tts --text 'This is an example.' --config_path {continue_config_path} --model_path {continue_restore_path} --out_path {out_wav_path}"
- run_cli(inference_command)
-
- # restore the model and continue training for one more epoch
- command_train = (
- f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --continue_path {continue_path} "
- )
- run_cli(command_train)
- shutil.rmtree(continue_path)
diff --git a/tests/vc_tests/test_freevc.py b/tests/vc_tests/test_freevc.py
index fe07b2723c..784e32a68d 100644
--- a/tests/vc_tests/test_freevc.py
+++ b/tests/vc_tests/test_freevc.py
@@ -55,7 +55,7 @@ def _test_forward(self, batch_size):
config = FreeVCConfig()
model = FreeVC(config).to(device)
model.train()
- print(" > Num parameters for FreeVC model:%s" % (count_parameters(model)))
+ print(f" > Num parameters for FreeVC model:{count_parameters(model)}")
mel, spec, spec_lengths, waveform = self._create_inputs(config, batch_size)
@@ -80,9 +80,9 @@ def _test_inference(self, batch_size):
wavlm_vec_lengths = torch.ones(batch_size, dtype=torch.long)
output_wav = model.inference(wavlm_vec, None, mel, wavlm_vec_lengths)
- assert (
- output_wav.shape[-1] // config.audio.hop_length == wavlm_vec.shape[-1]
- ), f"{output_wav.shape[-1] // config.audio.hop_length} != {wavlm_vec.shape}"
+ assert output_wav.shape[-1] // config.audio.hop_length == wavlm_vec.shape[-1], (
+ f"{output_wav.shape[-1] // config.audio.hop_length} != {wavlm_vec.shape}"
+ )
def test_inference(self):
self._test_inference(1)
@@ -95,9 +95,9 @@ def test_voice_conversion(self):
source_wav, target_wav = self._create_inputs_inference()
output_wav = model.voice_conversion(source_wav, target_wav)
- assert (
- output_wav.shape[0] == source_wav.shape[0] - source_wav.shape[0] % config.audio.hop_length
- ), f"{output_wav.shape} != {source_wav.shape}, {config.audio.hop_length}"
+ assert output_wav.shape[0] == source_wav.shape[0] - source_wav.shape[0] % config.audio.hop_length, (
+ f"{output_wav.shape} != {source_wav.shape}, {config.audio.hop_length}"
+ )
def test_train_step(self): ...
diff --git a/tests/vc_tests/test_openvoice.py b/tests/vc_tests/test_openvoice.py
index c9f7ae3931..703873ea47 100644
--- a/tests/vc_tests/test_openvoice.py
+++ b/tests/vc_tests/test_openvoice.py
@@ -16,7 +16,6 @@
class TestOpenVoice(unittest.TestCase):
-
@staticmethod
def _create_inputs_inference():
source_wav = torch.rand(16100)
@@ -37,6 +36,6 @@ def test_voice_conversion(self):
source_wav, target_wav = self._create_inputs_inference()
output_wav = model.voice_conversion(source_wav, target_wav)
- assert (
- output_wav.shape[0] == source_wav.shape[0] - source_wav.shape[0] % config.audio.hop_length
- ), f"{output_wav.shape} != {source_wav.shape}"
+ assert output_wav.shape[0] == source_wav.shape[0] - source_wav.shape[0] % config.audio.hop_length, (
+ f"{output_wav.shape} != {source_wav.shape}"
+ )
diff --git a/tests/vocoder_tests/test_wavegrad.py b/tests/vocoder_tests/test_wavegrad.py
index 7530bec426..d1d3610b70 100644
--- a/tests/vocoder_tests/test_wavegrad.py
+++ b/tests/vocoder_tests/test_wavegrad.py
@@ -47,6 +47,4 @@ def test_train_step():
for i, (param, param_ref) in enumerate(zip(model.parameters(), model_ref.parameters())):
        # ignore the pre-highway layer since it works conditionally
        # if count not in [145, 59]:
- assert (param != param_ref).any(), "param {} with shape {} not updated!! \n{}\n{}".format(
- i, param.shape, param, param_ref
- )
+ assert (param != param_ref).any(), f"param {i} with shape {param.shape} not updated!! \n{param}\n{param_ref}"