Skip to content

Commit 71426ea

Browse files
committed
feat(api): support passing speaker/language id file paths
1 parent 6d7ae99 commit 71426ea

File tree

1 file changed

+10
-4
lines changed

1 file changed

+10
-4
lines changed

TTS/api.py

+10-4
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,8 @@ def __init__(
3030
vocoder_config_path: Optional[str] = None,
3131
encoder_path: Optional[str] = None,
3232
encoder_config_path: Optional[str] = None,
33+
speakers_file_path: Optional[str] = None,
34+
language_ids_file_path: Optional[str] = None,
3335
progress_bar: bool = True,
3436
gpu: bool = False,
3537
) -> None:
@@ -68,8 +70,10 @@ def __init__(
6870
vocoder_config_path (str, optional): Path to the vocoder config. Defaults to None.
6971
encoder_path: Path to speaker encoder checkpoint. Default to None.
7072
encoder_config_path: Path to speaker encoder config file. Defaults to None.
71-
progress_bar (bool, optional): Whether to pring a progress bar while downloading a model. Defaults to True.
72-
gpu (bool, optional): Enable/disable GPU. Some models might be too slow on CPU. Defaults to False.
73+
speakers_file_path: JSON file for multi-speaker model. Defaults to None.
74+
language_ids_file_path: JSON file for multilingual model. Defaults to None
75+
progress_bar (bool, optional): Whether to print a progress bar while downloading a model. Defaults to True.
76+
gpu (bool, optional): Enable/disable GPU. Defaults to False. DEPRECATED, use TTS(...).to("cuda")
7377
"""
7478
super().__init__()
7579
self.manager = ModelManager(models_file=self.get_models_file_path(), progress_bar=progress_bar)
@@ -82,6 +86,8 @@ def __init__(
8286
self.vocoder_config_path = vocoder_config_path
8387
self.encoder_path = encoder_path
8488
self.encoder_config_path = encoder_config_path
89+
self.speakers_file_path = speakers_file_path
90+
self.language_ids_file_path = language_ids_file_path
8591

8692
if gpu:
8793
warnings.warn("`gpu` will be deprecated. Please use `tts.to(device)` instead.")
@@ -226,8 +232,8 @@ def load_tts_model_by_path(self, model_path: str, config_path: str, *, gpu: bool
226232
self.synthesizer = Synthesizer(
227233
tts_checkpoint=model_path,
228234
tts_config_path=config_path,
229-
tts_speakers_file=None,
230-
tts_languages_file=None,
235+
tts_speakers_file=self.speakers_file_path,
236+
tts_languages_file=self.language_ids_file_path,
231237
vocoder_checkpoint=self.vocoder_path,
232238
vocoder_config=self.vocoder_config_path,
233239
encoder_checkpoint=self.encoder_path,

0 commit comments

Comments
 (0)