Skip to content

Commit 58a11ab

Browse files
committed
feat: support vocoders for voice conversion
Until now, the supported voice conversion models (FreeVC and OpenVoice) were both VITS-based and did not need a separate vocoder. kNN-VC, however, needs to be combined with a HiFi-GAN vocoder.
1 parent 8848d50 commit 58a11ab

File tree

2 files changed: +34 −21 lines changed

TTS/api.py

+27-16
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@ def __init__(
9595
if "tts_models" in model_name:
9696
self.load_tts_model_by_name(model_name, vocoder_name, gpu=gpu)
9797
elif "voice_conversion_models" in model_name:
98-
self.load_vc_model_by_name(model_name, gpu=gpu)
98+
self.load_vc_model_by_name(model_name, vocoder_name, gpu=gpu)
9999
# To allow just TTS("xtts")
100100
else:
101101
self.load_model_by_name(model_name, vocoder_name, gpu=gpu)
@@ -157,22 +157,24 @@ def list_models() -> list[str]:
157157

158158
def download_model_by_name(
159159
self, model_name: str, vocoder_name: Optional[str] = None
160-
) -> tuple[Optional[Path], Optional[Path], Optional[Path]]:
160+
) -> tuple[Optional[Path], Optional[Path], Optional[Path], Optional[Path], Optional[Path]]:
161161
model_path, config_path, model_item = self.manager.download_model(model_name)
162162
if "fairseq" in model_name or (model_item is not None and isinstance(model_item["model_url"], list)):
163163
# return model directory if there are multiple files
164164
# we assume that the model knows how to load itself
165-
return None, None, model_path
165+
return None, None, None, None, model_path
166166
if model_item.get("default_vocoder") is None:
167-
return model_path, config_path, None
167+
return model_path, config_path, None, None, None
168168
if vocoder_name is None:
169169
vocoder_name = model_item["default_vocoder"]
170-
vocoder_path, vocoder_config_path, _ = self.manager.download_model(vocoder_name)
171-
# A local vocoder model will take precedence if specified via vocoder_path
172-
if self.vocoder_path is None or self.vocoder_config_path is None:
173-
self.vocoder_path = vocoder_path
174-
self.vocoder_config_path = vocoder_config_path
175-
return model_path, config_path, None
170+
vocoder_path, vocoder_config_path = None, None
171+
# A local vocoder model will take precedence if already specified in __init__
172+
if model_item["model_type"] == "tts_models":
173+
vocoder_path = self.vocoder_path
174+
vocoder_config_path = self.vocoder_config_path
175+
if vocoder_path is None or vocoder_config_path is None:
176+
vocoder_path, vocoder_config_path, _ = self.manager.download_model(vocoder_name)
177+
return model_path, config_path, vocoder_path, vocoder_config_path, None
176178

177179
def load_model_by_name(self, model_name: str, vocoder_name: Optional[str] = None, *, gpu: bool = False) -> None:
178180
"""Load one of the 🐸TTS models by name.
@@ -183,17 +185,24 @@ def load_model_by_name(self, model_name: str, vocoder_name: Optional[str] = None
183185
"""
184186
self.load_tts_model_by_name(model_name, vocoder_name, gpu=gpu)
185187

186-
def load_vc_model_by_name(self, model_name: str, *, gpu: bool = False) -> None:
188+
def load_vc_model_by_name(self, model_name: str, vocoder_name: Optional[str] = None, *, gpu: bool = False) -> None:
187189
"""Load one of the voice conversion models by name.
188190
189191
Args:
190192
model_name (str): Model name to load. You can list models by ```tts.models```.
191193
gpu (bool, optional): Enable/disable GPU. Some models might be too slow on CPU. Defaults to False.
192194
"""
193195
self.model_name = model_name
194-
model_path, config_path, model_dir = self.download_model_by_name(model_name)
196+
model_path, config_path, vocoder_path, vocoder_config_path, model_dir = self.download_model_by_name(
197+
model_name, vocoder_name
198+
)
195199
self.voice_converter = Synthesizer(
196-
vc_checkpoint=model_path, vc_config=config_path, model_dir=model_dir, use_cuda=gpu
200+
vc_checkpoint=model_path,
201+
vc_config=config_path,
202+
vocoder_checkpoint=vocoder_path,
203+
vocoder_config=vocoder_config_path,
204+
model_dir=model_dir,
205+
use_cuda=gpu,
197206
)
198207

199208
def load_tts_model_by_name(self, model_name: str, vocoder_name: Optional[str] = None, *, gpu: bool = False) -> None:
@@ -208,7 +217,9 @@ def load_tts_model_by_name(self, model_name: str, vocoder_name: Optional[str] =
208217
self.synthesizer = None
209218
self.model_name = model_name
210219

211-
model_path, config_path, model_dir = self.download_model_by_name(model_name, vocoder_name)
220+
model_path, config_path, vocoder_path, vocoder_config_path, model_dir = self.download_model_by_name(
221+
model_name, vocoder_name
222+
)
212223

213224
# init synthesizer
214225
# None values are fetch from the model
@@ -217,8 +228,8 @@ def load_tts_model_by_name(self, model_name: str, vocoder_name: Optional[str] =
217228
tts_config_path=config_path,
218229
tts_speakers_file=None,
219230
tts_languages_file=None,
220-
vocoder_checkpoint=self.vocoder_path,
221-
vocoder_config=self.vocoder_config_path,
231+
vocoder_checkpoint=vocoder_path,
232+
vocoder_config=vocoder_config_path,
222233
encoder_checkpoint=self.encoder_path,
223234
encoder_config=self.encoder_config_path,
224235
model_dir=model_dir,

TTS/utils/synthesizer.py

+7-5
Original file line numberDiff line numberDiff line change
@@ -98,12 +98,12 @@ def __init__(
9898
if tts_checkpoint:
9999
self._load_tts(self.tts_checkpoint, self.tts_config_path, use_cuda)
100100

101-
if vocoder_checkpoint:
102-
self._load_vocoder(self.vocoder_checkpoint, self.vocoder_config, use_cuda)
103-
104101
if vc_checkpoint and model_dir == "":
105102
self._load_vc(self.vc_checkpoint, self.vc_config, use_cuda)
106103

104+
if vocoder_checkpoint:
105+
self._load_vocoder(self.vocoder_checkpoint, self.vocoder_config, use_cuda)
106+
107107
if model_dir:
108108
if "fairseq" in model_dir:
109109
self._load_fairseq_from_dir(model_dir, use_cuda)
@@ -273,8 +273,10 @@ def save_wav(self, wav: List[int], path: str, pipe_out=None) -> None:
273273
save_wav(wav=wav, path=path, sample_rate=self.output_sample_rate, pipe_out=pipe_out)
274274

275275
def voice_conversion(self, source_wav: str, target_wav: str) -> List[int]:
276-
output_wav = self.vc_model.voice_conversion(source_wav, target_wav)
277-
return output_wav
276+
output = self.vc_model.voice_conversion(source_wav, target_wav)
277+
if self.vocoder_model is not None:
278+
output = self.vocoder_model.inference(output)
279+
return output.squeeze()
278280

279281
def tts(
280282
self,

0 commit comments

Comments
 (0)