
Commit ff2cd5c

Merge pull request #49 from idiap/vc-refactors
VC-related refactors and fixes
2 parents 98c0f86 + 4bd3df2 commit ff2cd5c

18 files changed: +76, -184 lines

TTS/bin/train_encoder.py (+4 -3)

@@ -161,9 +161,6 @@ def train(model, optimizer, scheduler, criterion, data_loader, eval_data_loader,
         loader_time = time.time() - end_time
         global_step += 1
 
-        # setup lr
-        if c.lr_decay:
-            scheduler.step()
         optimizer.zero_grad()
 
         # dispatch data to GPU
@@ -182,6 +179,10 @@ def train(model, optimizer, scheduler, criterion, data_loader, eval_data_loader,
         grad_norm, _ = check_update(model, c.grad_clip)
         optimizer.step()
 
+        # setup lr
+        if c.lr_decay:
+            scheduler.step()
+
         step_time = time.time() - start_time
         epoch_time += step_time
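
Note on the hunk above: since PyTorch 1.1.0, `scheduler.step()` is expected to run after `optimizer.step()`; the old order skipped the first value of the schedule and triggered a UserWarning. A minimal sketch of the corrected loop, assuming a plain SGD + StepLR setup rather than the project's own config:

import torch

model = torch.nn.Linear(4, 4)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5)

for step in range(30):
    optimizer.zero_grad()
    loss = model(torch.randn(2, 4)).sum()
    loss.backward()
    optimizer.step()
    scheduler.step()  # decay the LR only after the optimizer update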

TTS/encoder/dataset.py (+1 -1)

@@ -55,7 +55,7 @@ def __init__(
         logger.info(" | Number of instances: %d", len(self.items))
         logger.info(" | Sequence length: %d", self.seq_len)
         logger.info(" | Number of classes: %d", len(self.classes))
-        logger.info(" | Classes: %d", self.classes)
+        logger.info(" | Classes: %s", self.classes)
 
     def load_wav(self, filename):
         audio = self.ap.load_wav(filename, sr=self.ap.sample_rate)
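
`self.classes` is a list, so the `%d` placeholder made the logging call fail at format time; `%s` accepts any object. Quick repro, assuming only stdlib logging:

import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

classes = ["speaker_a", "speaker_b"]
logger.info(" | Classes: %s", classes)  # fine: %s calls str() on the list
logger.info(" | Classes: %d", classes)  # fails inside the handler: %d requires a number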

TTS/model.py (+12 -6)

@@ -1,5 +1,6 @@
+import os
 from abc import abstractmethod
-from typing import Dict
+from typing import Any, Union
 
 import torch
 from coqpit import Coqpit
@@ -16,15 +17,15 @@ class BaseTrainerModel(TrainerModel):
 
     @staticmethod
     @abstractmethod
-    def init_from_config(config: Coqpit):
+    def init_from_config(config: Coqpit) -> "BaseTrainerModel":
         """Init the model and all its attributes from the given config.
 
         Override this depending on your model.
         """
         ...
 
     @abstractmethod
-    def inference(self, input: torch.Tensor, aux_input={}) -> Dict:
+    def inference(self, input: torch.Tensor, aux_input: dict[str, Any] = {}) -> dict[str, Any]:
         """Forward pass for inference.
 
         It must return a dictionary with the main model output and all the auxiliary outputs. The key ```model_outputs```
@@ -45,13 +46,18 @@ def inference(self, input: torch.Tensor, aux_input={}) -> Dict:
 
     @abstractmethod
     def load_checkpoint(
-        self, config: Coqpit, checkpoint_path: str, eval: bool = False, strict: bool = True, cache=False
+        self,
+        config: Coqpit,
+        checkpoint_path: Union[str, os.PathLike[Any]],
+        eval: bool = False,
+        strict: bool = True,
+        cache: bool = False,
     ) -> None:
-        """Load a model checkpoint gile and get ready for training or inference.
+        """Load a model checkpoint file and get ready for training or inference.
 
         Args:
             config (Coqpit): Model configuration.
-            checkpoint_path (str): Path to the model checkpoint file.
+            checkpoint_path (str | os.PathLike): Path to the model checkpoint file.
             eval (bool, optional): If true, init model for inference else for training. Defaults to False.
             strict (bool, optional): Match all checkpoint keys to model's keys. Defaults to True.
             cache (bool, optional): If True, cache the file locally for subsequent calls. It is cached under `get_user_data_dir()/tts_cache`. Defaults to False.
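
Broadening `checkpoint_path` to `os.PathLike` lets callers pass `pathlib.Path` objects without `str()` casts. A small illustration of the pattern with a hypothetical helper (`normalize_path` is not part of the API):

import os
from pathlib import Path
from typing import Any, Union

def normalize_path(checkpoint_path: Union[str, os.PathLike[Any]]) -> str:
    # os.fspath() accepts str and any PathLike and returns a plain string
    return os.fspath(checkpoint_path)

assert normalize_path("best_model.pth") == "best_model.pth"
assert normalize_path(Path("run") / "best_model.pth") == os.path.join("run", "best_model.pth")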

TTS/tts/layers/glow_tts/transformer.py (+3 -8)

@@ -5,6 +5,7 @@
 from torch.nn import functional as F
 
 from TTS.tts.layers.generic.normalization import LayerNorm, LayerNorm2
+from TTS.tts.utils.helpers import convert_pad_shape
 
 
 class RelativePositionMultiHeadAttention(nn.Module):
@@ -300,7 +301,7 @@ def _causal_padding(self, x):
         pad_l = self.kernel_size - 1
         pad_r = 0
         padding = [[0, 0], [0, 0], [pad_l, pad_r]]
-        x = F.pad(x, self._pad_shape(padding))
+        x = F.pad(x, convert_pad_shape(padding))
         return x
 
     def _same_padding(self, x):
@@ -309,15 +310,9 @@ def _same_padding(self, x):
         pad_l = (self.kernel_size - 1) // 2
         pad_r = self.kernel_size // 2
         padding = [[0, 0], [0, 0], [pad_l, pad_r]]
-        x = F.pad(x, self._pad_shape(padding))
+        x = F.pad(x, convert_pad_shape(padding))
         return x
 
-    @staticmethod
-    def _pad_shape(padding):
-        l = padding[::-1]
-        pad_shape = [item for sublist in l for item in sublist]
-        return pad_shape
-
 
 class RelativePositionTransformer(nn.Module):
     """Transformer with Relative Potional Encoding.

TTS/tts/layers/losses.py (+1 -1)

@@ -255,7 +255,7 @@ def forward(self, att_ws, ilens, olens):
 
     @staticmethod
     def _make_ga_mask(ilen, olen, sigma):
-        grid_x, grid_y = torch.meshgrid(torch.arange(olen).to(olen), torch.arange(ilen).to(ilen))
+        grid_x, grid_y = torch.meshgrid(torch.arange(olen).to(olen), torch.arange(ilen).to(ilen), indexing="ij")
         grid_x, grid_y = grid_x.float(), grid_y.float()
         return 1.0 - torch.exp(-((grid_y / ilen - grid_x / olen) ** 2) / (2 * (sigma**2)))
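
Passing `indexing="ij"` keeps the historical default (matrix indexing) while silencing the UserWarning PyTorch 1.10+ emits when the argument is omitted. A quick sanity check with plain integers in place of the loss's length tensors:

import torch

grid_x, grid_y = torch.meshgrid(torch.arange(4), torch.arange(3), indexing="ij")
assert grid_x.shape == (4, 3)
assert grid_x[2, 0] == 2  # first input varies along rows ("ij")
assert grid_y[0, 2] == 2  # second input varies along columns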

TTS/tts/layers/overflow/neural_hmm.py (+2 -1)

@@ -128,7 +128,8 @@ def forward(self, inputs, inputs_len, mels, mel_lens):
             # Get mean, std and transition vector from decoder for this timestep
             # Note: Gradient checkpointing currently doesn't works with multiple gpus inside a loop
             if self.use_grad_checkpointing and self.training:
-                mean, std, transition_vector = checkpoint(self.output_net, h_memory, inputs)
+                # TODO: use_reentrant=False is recommended
+                mean, std, transition_vector = checkpoint(self.output_net, h_memory, inputs, use_reentrant=True)
             else:
                 mean, std, transition_vector = self.output_net(h_memory, inputs)
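
Newer PyTorch warns when `use_reentrant` is not passed explicitly; the hunk pins the legacy behavior the model was written against and leaves the recommended `use_reentrant=False` as a TODO. A standalone sketch of the call, assuming a toy module:

import torch
from torch.utils.checkpoint import checkpoint

net = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.ReLU(), torch.nn.Linear(8, 8))
x = torch.randn(2, 8, requires_grad=True)  # reentrant mode needs a grad-requiring input

y = checkpoint(net, x, use_reentrant=True)  # activations are recomputed during backward
y.sum().backward()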

TTS/tts/layers/vits/networks.py (-16)

@@ -10,22 +10,6 @@
 LRELU_SLOPE = 0.1
 
 
-def convert_pad_shape(pad_shape):
-    l = pad_shape[::-1]
-    pad_shape = [item for sublist in l for item in sublist]
-    return pad_shape
-
-
-def init_weights(m, mean=0.0, std=0.01):
-    classname = m.__class__.__name__
-    if classname.find("Conv") != -1:
-        m.weight.data.normal_(mean, std)
-
-
-def get_padding(kernel_size, dilation=1):
-    return int((kernel_size * dilation - dilation) / 2)
-
-
 class TextEncoder(nn.Module):
     def __init__(
         self,

TTS/tts/layers/xtts/hifigan_decoder.py (+1 -4)

@@ -9,16 +9,13 @@
 from torch.nn.utils.parametrize import remove_parametrizations
 
 from TTS.utils.io import load_fsspec
+from TTS.vocoder.models.hifigan_generator import get_padding
 
 logger = logging.getLogger(__name__)
 
 LRELU_SLOPE = 0.1
 
 
-def get_padding(k, d):
-    return int((k * d - d) / 2)
-
-
 class ResBlock1(torch.nn.Module):
     """Residual Block Type 1. It has 3 convolutional layers in each convolutional block.
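
`get_padding` now has a single home in `TTS.vocoder.models.hifigan_generator`. The formula gives "same" padding for odd kernels, including dilated ones; a self-contained check of the removed local copy:

import torch

def get_padding(kernel_size: int, dilation: int = 1) -> int:
    # a dilated kernel spans dilation * (kernel_size - 1) + 1 samples
    return int((kernel_size * dilation - dilation) / 2)

conv = torch.nn.Conv1d(1, 1, kernel_size=3, dilation=2, padding=get_padding(3, 2))
x = torch.randn(1, 1, 10)
assert conv(x).shape == x.shape  # sequence length preserved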

TTS/tts/models/base_tts.py (+1 -1)

@@ -144,7 +144,7 @@ def get_aux_input_from_test_sentences(self, sentence_info):
             if speaker_name is None:
                 d_vector = self.speaker_manager.get_random_embedding()
             else:
-                d_vector = self.speaker_manager.get_d_vector_by_name(speaker_name)
+                d_vector = self.speaker_manager.get_mean_embedding(speaker_name)
         elif config.use_speaker_embedding:
             if speaker_name is None:
                 speaker_id = self.speaker_manager.get_random_id()
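
`get_d_vector_by_name` appears to be a stale name from an earlier SpeakerManager API, so this branch failed with an AttributeError; `get_mean_embedding` returns the average of a speaker's stored d-vectors. A rough standalone sketch of that idea (hypothetical, not the library code):

import numpy as np

embeddings = {
    "clip1.wav": {"name": "speaker_a", "embedding": [0.0, 1.0]},
    "clip2.wav": {"name": "speaker_a", "embedding": [1.0, 0.0]},
}

def mean_embedding(name: str) -> np.ndarray:
    # average every stored embedding that belongs to this speaker
    vecs = [np.asarray(e["embedding"]) for e in embeddings.values() if e["name"] == name]
    return np.stack(vecs).mean(axis=0)

print(mean_embedding("speaker_a"))  # [0.5 0.5]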

TTS/tts/models/delightful_tts.py (-6)

@@ -88,12 +88,6 @@ def pad(input_ele: List[torch.Tensor], max_len: int) -> torch.Tensor:
     return out_padded
 
 
-def init_weights(m: nn.Module, mean: float = 0.0, std: float = 0.01):
-    classname = m.__class__.__name__
-    if classname.find("Conv") != -1:
-        m.weight.data.normal_(mean, std)
-
-
 def stride_lens(lens: torch.Tensor, stride: int = 2) -> torch.Tensor:
     return torch.ceil(lens / stride).int()

TTS/tts/utils/helpers.py (+2 -3)

@@ -145,10 +145,9 @@ def average_over_durations(values, durs):
     return avg
 
 
-def convert_pad_shape(pad_shape):
+def convert_pad_shape(pad_shape: list[list]) -> list:
     l = pad_shape[::-1]
-    pad_shape = [item for sublist in l for item in sublist]
-    return pad_shape
+    return [item for sublist in l for item in sublist]
 
 
 def generate_path(duration, mask):

TTS/vc/models/base_vc.py (+24 -22)

@@ -1,7 +1,7 @@
 import logging
 import os
 import random
-from typing import Dict, List, Tuple, Union
+from typing import Any, Optional, Union
 
 import torch
 import torch.distributed as dist
@@ -10,6 +10,7 @@
 from torch.utils.data import DataLoader
 from torch.utils.data.sampler import WeightedRandomSampler
 from trainer.torch import DistributedSampler, DistributedSamplerWrapper
+from trainer.trainer import Trainer
 
 from TTS.model import BaseTrainerModel
 from TTS.tts.datasets.dataset import TTSDataset
@@ -18,6 +19,7 @@
 from TTS.tts.utils.speakers import SpeakerManager, get_speaker_balancer_weights
 from TTS.tts.utils.synthesis import synthesis
 from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
+from TTS.utils.audio.processor import AudioProcessor
 
 # pylint: skip-file
 
@@ -35,18 +37,18 @@ class BaseVC(BaseTrainerModel):
     def __init__(
         self,
         config: Coqpit,
-        ap: "AudioProcessor",
-        speaker_manager: SpeakerManager = None,
-        language_manager: LanguageManager = None,
-    ):
+        ap: AudioProcessor,
+        speaker_manager: Optional[SpeakerManager] = None,
+        language_manager: Optional[LanguageManager] = None,
+    ) -> None:
         super().__init__()
         self.config = config
         self.ap = ap
         self.speaker_manager = speaker_manager
         self.language_manager = language_manager
         self._set_model_args(config)
 
-    def _set_model_args(self, config: Coqpit):
+    def _set_model_args(self, config: Coqpit) -> None:
         """Setup model args based on the config type (`ModelConfig` or `ModelArgs`).
 
         `ModelArgs` has all the fields reuqired to initialize the model architecture.
@@ -67,7 +69,7 @@ def _set_model_args(self, config: Coqpit):
         else:
             raise ValueError("config must be either a *Config or *Args")
 
-    def init_multispeaker(self, config: Coqpit, data: List = None):
+    def init_multispeaker(self, config: Coqpit, data: Optional[list[Any]] = None) -> None:
         """Initialize a speaker embedding layer if needen and define expected embedding channel size for defining
         `in_channels` size of the connected layers.
 
@@ -100,11 +102,11 @@ def init_multispeaker(self, config: Coqpit, data: List = None):
             self.speaker_embedding = nn.Embedding(self.num_speakers, self.embedded_speaker_dim)
             self.speaker_embedding.weight.data.normal_(0, 0.3)
 
-    def get_aux_input(self, **kwargs) -> Dict:
+    def get_aux_input(self, **kwargs: Any) -> dict[str, Any]:
         """Prepare and return `aux_input` used by `forward()`"""
         return {"speaker_id": None, "style_wav": None, "d_vector": None, "language_id": None}
 
-    def get_aux_input_from_test_sentences(self, sentence_info):
+    def get_aux_input_from_test_sentences(self, sentence_info: Union[str, list[str]]) -> dict[str, Any]:
         if hasattr(self.config, "model_args"):
             config = self.config.model_args
         else:
@@ -132,7 +134,7 @@ def get_aux_input_from_test_sentences(self, sentence_info):
             if speaker_name is None:
                 d_vector = self.speaker_manager.get_random_embedding()
             else:
-                d_vector = self.speaker_manager.get_d_vector_by_name(speaker_name)
+                d_vector = self.speaker_manager.get_mean_embedding(speaker_name)
         elif config.use_speaker_embedding:
             if speaker_name is None:
                 speaker_id = self.speaker_manager.get_random_id()
@@ -151,16 +153,16 @@ def get_aux_input_from_test_sentences(self, sentence_info):
             "language_id": language_id,
         }
 
-    def format_batch(self, batch: Dict) -> Dict:
+    def format_batch(self, batch: dict[str, Any]) -> dict[str, Any]:
         """Generic batch formatting for `VCDataset`.
 
         You must override this if you use a custom dataset.
 
         Args:
-            batch (Dict): [description]
+            batch (dict): [description]
 
         Returns:
-            Dict: [description]
+            dict: [description]
         """
         # setup input batch
         text_input = batch["token_id"]
@@ -230,7 +232,7 @@ def format_batch(self, batch: Dict) -> Dict:
             "audio_unique_names": batch["audio_unique_names"],
         }
 
-    def get_sampler(self, config: Coqpit, dataset: TTSDataset, num_gpus=1):
+    def get_sampler(self, config: Coqpit, dataset: TTSDataset, num_gpus: int = 1):
         weights = None
         data_items = dataset.samples
 
@@ -271,12 +273,12 @@ def get_sampler(self, config: Coqpit, dataset: TTSDataset, num_gpus=1):
     def get_data_loader(
         self,
         config: Coqpit,
-        assets: Dict,
+        assets: dict,
         is_eval: bool,
-        samples: Union[List[Dict], List[List]],
+        samples: Union[list[dict], list[list]],
         verbose: bool,
         num_gpus: int,
-        rank: int = None,
+        rank: Optional[int] = None,
     ) -> "DataLoader":
         if is_eval and not config.run_eval:
             loader = None
@@ -352,9 +354,9 @@ def get_data_loader(
 
     def _get_test_aux_input(
         self,
-    ) -> Dict:
+    ) -> dict[str, Any]:
         d_vector = None
-        if self.config.use_d_vector_file:
+        if self.speaker_manager is not None and self.config.use_d_vector_file:
             d_vector = [self.speaker_manager.embeddings[name]["embedding"] for name in self.speaker_manager.embeddings]
             d_vector = (random.sample(sorted(d_vector), 1),)
 
@@ -369,7 +371,7 @@ def _get_test_aux_input(
         }
         return aux_inputs
 
-    def test_run(self, assets: Dict) -> Tuple[Dict, Dict]:
+    def test_run(self, assets: dict) -> tuple[dict, dict]:
         """Generic test run for `vc` models used by `Trainer`.
 
         You can override this for a different behaviour.
@@ -378,7 +380,7 @@ def test_run(self, assets: Dict) -> Tuple[Dict, Dict]:
             assets (dict): A dict of training assets. For `vc` models, it must include `{'audio_processor': ap}`.
 
         Returns:
-            Tuple[Dict, Dict]: Test figures and audios to be projected to Tensorboard.
+            tuple[dict, dict]: Test figures and audios to be projected to Tensorboard.
         """
         logger.info("Synthesizing test sentences.")
         test_audios = {}
@@ -409,7 +411,7 @@ def test_run(self, assets: Dict) -> Tuple[Dict, Dict]:
         )
         return test_figures, test_audios
 
-    def on_init_start(self, trainer):
+    def on_init_start(self, trainer: Trainer) -> None:
         """Save the speaker.pth and language_ids.json at the beginning of the training. Also update both paths."""
         if self.speaker_manager is not None:
             output_path = os.path.join(trainer.output_path, "speakers.pth")
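
Two things worth noting in this file: the annotations move from `typing.Dict`/`List`/`Tuple` to builtin generics with explicit `Optional[...]` where the default is None, and `_get_test_aux_input` now guards the optional speaker manager before dereferencing it. The guard pattern, reduced to a sketch with a hypothetical stub class:

from typing import Optional

class SpeakerManagerStub:
    embeddings: dict = {}

def pick_d_vector(speaker_manager: Optional[SpeakerManagerStub], use_d_vector_file: bool):
    # test the Optional value first; short-circuiting avoids the AttributeError
    if speaker_manager is not None and use_d_vector_file:
        return list(speaker_manager.embeddings)
    return None

assert pick_d_vector(None, True) is None  # previously this path crashed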

TTS/vc/models/freevc.py (+4 -2)

@@ -14,14 +14,16 @@
 
 import TTS.vc.modules.freevc.commons as commons
 import TTS.vc.modules.freevc.modules as modules
+from TTS.tts.utils.helpers import sequence_mask
 from TTS.tts.utils.speakers import SpeakerManager
 from TTS.utils.io import load_fsspec
 from TTS.vc.configs.freevc_config import FreeVCConfig
 from TTS.vc.models.base_vc import BaseVC
-from TTS.vc.modules.freevc.commons import get_padding, init_weights
+from TTS.vc.modules.freevc.commons import init_weights
 from TTS.vc.modules.freevc.mel_processing import mel_spectrogram_torch
 from TTS.vc.modules.freevc.speaker_encoder.speaker_encoder import SpeakerEncoder as SpeakerEncoderEx
 from TTS.vc.modules.freevc.wavlm import get_wavlm
+from TTS.vocoder.models.hifigan_generator import get_padding
 
 logger = logging.getLogger(__name__)
 
@@ -80,7 +82,7 @@ def __init__(
         self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
 
     def forward(self, x, x_lengths, g=None):
-        x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)
+        x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)
         x = self.pre(x) * x_mask
         x = self.enc(x, x_mask, g=g)
         stats = self.proj(x) * x_mask
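
`sequence_mask` is now imported from the shared `TTS.tts.utils.helpers` instead of the FreeVC-local `commons` copy. An equivalent standalone version, assuming the usual semantics (True for valid frames, False for padding):

import torch

def sequence_mask(sequence_length: torch.Tensor, max_len: int) -> torch.Tensor:
    seq_range = torch.arange(max_len, device=sequence_length.device)
    return seq_range.unsqueeze(0) < sequence_length.unsqueeze(1)

lengths = torch.tensor([2, 4])
print(sequence_mask(lengths, 4))
# tensor([[ True,  True, False, False],
#         [ True,  True,  True,  True]])

x_mask = sequence_mask(lengths, 4).unsqueeze(1).float()  # (b, 1, t), as in forward() above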
