1
1
import logging
2
2
import os
3
3
import random
4
- from typing import Dict , List , Tuple , Union
4
+ from typing import Any , Optional , Union
5
5
6
6
import torch
7
7
import torch .distributed as dist
10
10
from torch .utils .data import DataLoader
11
11
from torch .utils .data .sampler import WeightedRandomSampler
12
12
from trainer .torch import DistributedSampler , DistributedSamplerWrapper
13
+ from trainer .trainer import Trainer
13
14
14
15
from TTS .model import BaseTrainerModel
15
16
from TTS .tts .datasets .dataset import TTSDataset
18
19
from TTS .tts .utils .speakers import SpeakerManager , get_speaker_balancer_weights
19
20
from TTS .tts .utils .synthesis import synthesis
20
21
from TTS .tts .utils .visual import plot_alignment , plot_spectrogram
22
+ from TTS .utils .audio .processor import AudioProcessor
21
23
22
24
# pylint: skip-file
23
25
@@ -35,18 +37,18 @@ class BaseVC(BaseTrainerModel):
35
37
def __init__ (
36
38
self ,
37
39
config : Coqpit ,
38
- ap : " AudioProcessor" ,
39
- speaker_manager : SpeakerManager = None ,
40
- language_manager : LanguageManager = None ,
41
- ):
40
+ ap : AudioProcessor ,
41
+ speaker_manager : Optional [ SpeakerManager ] = None ,
42
+ language_manager : Optional [ LanguageManager ] = None ,
43
+ ) -> None :
42
44
super ().__init__ ()
43
45
self .config = config
44
46
self .ap = ap
45
47
self .speaker_manager = speaker_manager
46
48
self .language_manager = language_manager
47
49
self ._set_model_args (config )
48
50
49
- def _set_model_args (self , config : Coqpit ):
51
+ def _set_model_args (self , config : Coqpit ) -> None :
50
52
"""Setup model args based on the config type (`ModelConfig` or `ModelArgs`).
51
53
52
54
`ModelArgs` has all the fields reuqired to initialize the model architecture.
@@ -67,7 +69,7 @@ def _set_model_args(self, config: Coqpit):
67
69
else :
68
70
raise ValueError ("config must be either a *Config or *Args" )
69
71
70
- def init_multispeaker (self , config : Coqpit , data : List = None ):
72
+ def init_multispeaker (self , config : Coqpit , data : Optional [ list [ Any ]] = None ) -> None :
71
73
"""Initialize a speaker embedding layer if needen and define expected embedding channel size for defining
72
74
`in_channels` size of the connected layers.
73
75
@@ -100,11 +102,11 @@ def init_multispeaker(self, config: Coqpit, data: List = None):
100
102
self .speaker_embedding = nn .Embedding (self .num_speakers , self .embedded_speaker_dim )
101
103
self .speaker_embedding .weight .data .normal_ (0 , 0.3 )
102
104
103
- def get_aux_input (self , ** kwargs ) -> Dict :
105
+ def get_aux_input (self , ** kwargs : Any ) -> dict [ str , Any ] :
104
106
"""Prepare and return `aux_input` used by `forward()`"""
105
107
return {"speaker_id" : None , "style_wav" : None , "d_vector" : None , "language_id" : None }
106
108
107
- def get_aux_input_from_test_sentences (self , sentence_info ) :
109
+ def get_aux_input_from_test_sentences (self , sentence_info : Union [ str , list [ str ]]) -> dict [ str , Any ] :
108
110
if hasattr (self .config , "model_args" ):
109
111
config = self .config .model_args
110
112
else :
@@ -132,7 +134,7 @@ def get_aux_input_from_test_sentences(self, sentence_info):
132
134
if speaker_name is None :
133
135
d_vector = self .speaker_manager .get_random_embedding ()
134
136
else :
135
- d_vector = self .speaker_manager .get_d_vector_by_name (speaker_name )
137
+ d_vector = self .speaker_manager .get_mean_embedding (speaker_name )
136
138
elif config .use_speaker_embedding :
137
139
if speaker_name is None :
138
140
speaker_id = self .speaker_manager .get_random_id ()
@@ -151,16 +153,16 @@ def get_aux_input_from_test_sentences(self, sentence_info):
151
153
"language_id" : language_id ,
152
154
}
153
155
154
- def format_batch (self , batch : Dict ) -> Dict :
156
+ def format_batch (self , batch : dict [ str , Any ] ) -> dict [ str , Any ] :
155
157
"""Generic batch formatting for `VCDataset`.
156
158
157
159
You must override this if you use a custom dataset.
158
160
159
161
Args:
160
- batch (Dict ): [description]
162
+ batch (dict ): [description]
161
163
162
164
Returns:
163
- Dict : [description]
165
+ dict : [description]
164
166
"""
165
167
# setup input batch
166
168
text_input = batch ["token_id" ]
@@ -230,7 +232,7 @@ def format_batch(self, batch: Dict) -> Dict:
230
232
"audio_unique_names" : batch ["audio_unique_names" ],
231
233
}
232
234
233
- def get_sampler (self , config : Coqpit , dataset : TTSDataset , num_gpus = 1 ):
235
+ def get_sampler (self , config : Coqpit , dataset : TTSDataset , num_gpus : int = 1 ):
234
236
weights = None
235
237
data_items = dataset .samples
236
238
@@ -271,12 +273,12 @@ def get_sampler(self, config: Coqpit, dataset: TTSDataset, num_gpus=1):
271
273
def get_data_loader (
272
274
self ,
273
275
config : Coqpit ,
274
- assets : Dict ,
276
+ assets : dict ,
275
277
is_eval : bool ,
276
- samples : Union [List [ Dict ], List [ List ]],
278
+ samples : Union [list [ dict ], list [ list ]],
277
279
verbose : bool ,
278
280
num_gpus : int ,
279
- rank : int = None ,
281
+ rank : Optional [ int ] = None ,
280
282
) -> "DataLoader" :
281
283
if is_eval and not config .run_eval :
282
284
loader = None
@@ -352,9 +354,9 @@ def get_data_loader(
352
354
353
355
def _get_test_aux_input (
354
356
self ,
355
- ) -> Dict :
357
+ ) -> dict [ str , Any ] :
356
358
d_vector = None
357
- if self .config .use_d_vector_file :
359
+ if self .speaker_manager is not None and self . config .use_d_vector_file :
358
360
d_vector = [self .speaker_manager .embeddings [name ]["embedding" ] for name in self .speaker_manager .embeddings ]
359
361
d_vector = (random .sample (sorted (d_vector ), 1 ),)
360
362
@@ -369,7 +371,7 @@ def _get_test_aux_input(
369
371
}
370
372
return aux_inputs
371
373
372
- def test_run (self , assets : Dict ) -> Tuple [ Dict , Dict ]:
374
+ def test_run (self , assets : dict ) -> tuple [ dict , dict ]:
373
375
"""Generic test run for `vc` models used by `Trainer`.
374
376
375
377
You can override this for a different behaviour.
@@ -378,7 +380,7 @@ def test_run(self, assets: Dict) -> Tuple[Dict, Dict]:
378
380
assets (dict): A dict of training assets. For `vc` models, it must include `{'audio_processor': ap}`.
379
381
380
382
Returns:
381
- Tuple[Dict, Dict ]: Test figures and audios to be projected to Tensorboard.
383
+ tuple[dict, dict ]: Test figures and audios to be projected to Tensorboard.
382
384
"""
383
385
logger .info ("Synthesizing test sentences." )
384
386
test_audios = {}
@@ -409,7 +411,7 @@ def test_run(self, assets: Dict) -> Tuple[Dict, Dict]:
409
411
)
410
412
return test_figures , test_audios
411
413
412
- def on_init_start (self , trainer ) :
414
+ def on_init_start (self , trainer : Trainer ) -> None :
413
415
"""Save the speaker.pth and language_ids.json at the beginning of the training. Also update both paths."""
414
416
if self .speaker_manager is not None :
415
417
output_path = os .path .join (trainer .output_path , "speakers.pth" )
0 commit comments