@@ -95,7 +95,7 @@ def __init__(
95
95
if "tts_models" in model_name :
96
96
self .load_tts_model_by_name (model_name , vocoder_name , gpu = gpu )
97
97
elif "voice_conversion_models" in model_name :
98
- self .load_vc_model_by_name (model_name , gpu = gpu )
98
+ self .load_vc_model_by_name (model_name , vocoder_name , gpu = gpu )
99
99
# To allow just TTS("xtts")
100
100
else :
101
101
self .load_model_by_name (model_name , vocoder_name , gpu = gpu )
@@ -157,22 +157,24 @@ def list_models() -> list[str]:
157
157
158
158
def download_model_by_name(
    self, model_name: str, vocoder_name: Optional[str] = None
) -> tuple[Optional[Path], Optional[Path], Optional[Path], Optional[Path], Optional[Path]]:
    """Fetch a model by name and resolve the vocoder that goes with it.

    Args:
        model_name (str): Name of the model, passed to ``self.manager.download_model``.
        vocoder_name (str, optional): Name of the vocoder to fetch. When None,
            the model's ``default_vocoder`` entry is used, if it has one.

    Returns:
        A 5-tuple ``(model_path, config_path, vocoder_path, vocoder_config_path, model_dir)``.
        For fairseq models and models whose ``model_url`` is a list, only
        ``model_dir`` is set. Otherwise ``model_dir`` is None, and the two
        vocoder entries are None when no vocoder applies.
    """
    model_path, config_path, model_item = self.manager.download_model(model_name)
    if "fairseq" in model_name or (model_item is not None and isinstance(model_item["model_url"], list)):
        # return model directory if there are multiple files
        # we assume that the model knows how to load itself
        return None, None, None, None, model_path
    # Without model metadata there is no default vocoder to look up; the
    # branch above already treats a missing model_item as possible.
    if model_item is None or model_item.get("default_vocoder") is None:
        return model_path, config_path, None, None, None
    if vocoder_name is None:
        vocoder_name = model_item["default_vocoder"]
    vocoder_path, vocoder_config_path = None, None
    # A local vocoder model will take precedence if already specified in __init__.
    # Only TTS models reuse it; other model types always resolve by name.
    if model_item["model_type"] == "tts_models":
        vocoder_path = self.vocoder_path
        vocoder_config_path = self.vocoder_config_path
    # A half-specified local vocoder (checkpoint without config, or the
    # reverse) is replaced as a pair by the named one.
    if vocoder_path is None or vocoder_config_path is None:
        vocoder_path, vocoder_config_path, _ = self.manager.download_model(vocoder_name)
    return model_path, config_path, vocoder_path, vocoder_config_path, None
176
178
177
179
def load_model_by_name (self , model_name : str , vocoder_name : Optional [str ] = None , * , gpu : bool = False ) -> None :
178
180
"""Load one of the 🐸TTS models by name.
@@ -183,17 +185,24 @@ def load_model_by_name(self, model_name: str, vocoder_name: Optional[str] = None
183
185
"""
184
186
self .load_tts_model_by_name (model_name , vocoder_name , gpu = gpu )
185
187
186
def load_vc_model_by_name(self, model_name: str, vocoder_name: Optional[str] = None, *, gpu: bool = False) -> None:
    """Load one of the voice conversion models by name.

    Sets ``self.model_name`` and stores the constructed ``Synthesizer`` in
    ``self.voice_converter``.

    Args:
        model_name (str): Model name to load. You can list models by ```tts.models```.
        vocoder_name (str, optional): Vocoder name forwarded to
            ``download_model_by_name``. Defaults to None.
        gpu (bool, optional): Enable/disable GPU. Some models might be too slow on CPU. Defaults to False.
    """
    self.model_name = model_name
    model_path, config_path, vocoder_path, vocoder_config_path, model_dir = self.download_model_by_name(
        model_name, vocoder_name
    )
    self.voice_converter = Synthesizer(
        vc_checkpoint=model_path,
        vc_config=config_path,
        vocoder_checkpoint=vocoder_path,
        vocoder_config=vocoder_config_path,
        model_dir=model_dir,
        use_cuda=gpu,
    )
198
207
199
208
def load_tts_model_by_name (self , model_name : str , vocoder_name : Optional [str ] = None , * , gpu : bool = False ) -> None :
@@ -208,7 +217,9 @@ def load_tts_model_by_name(self, model_name: str, vocoder_name: Optional[str] =
208
217
self .synthesizer = None
209
218
self .model_name = model_name
210
219
211
- model_path , config_path , model_dir = self .download_model_by_name (model_name , vocoder_name )
220
+ model_path , config_path , vocoder_path , vocoder_config_path , model_dir = self .download_model_by_name (
221
+ model_name , vocoder_name
222
+ )
212
223
213
224
# init synthesizer
214
225
# None values are fetch from the model
@@ -217,8 +228,8 @@ def load_tts_model_by_name(self, model_name: str, vocoder_name: Optional[str] =
217
228
tts_config_path = config_path ,
218
229
tts_speakers_file = None ,
219
230
tts_languages_file = None ,
220
- vocoder_checkpoint = self . vocoder_path ,
221
- vocoder_config = self . vocoder_config_path ,
231
+ vocoder_checkpoint = vocoder_path ,
232
+ vocoder_config = vocoder_config_path ,
222
233
encoder_checkpoint = self .encoder_path ,
223
234
encoder_config = self .encoder_config_path ,
224
235
model_dir = model_dir ,
0 commit comments