Commit a57e91a

feat(vc): allow multiple target audio files

1 parent be06e41 commit a57e91a

8 files changed, +61 -49 lines

TTS/api.py (+11 -9)

@@ -4,7 +4,7 @@
 import tempfile
 import warnings
 from pathlib import Path
-from typing import Optional
+from typing import Optional, Union
 
 from torch import nn
 
@@ -388,7 +388,7 @@ def tts_to_file(
     def voice_conversion(
         self,
         source_wav: str,
-        target_wav: str,
+        target_wav: Union[str, list[str]],
     ):
         """Voice conversion with FreeVC. Convert source wav to target speaker.
 
@@ -406,7 +406,7 @@ def voice_conversion(
     def voice_conversion_to_file(
         self,
         source_wav: str,
-        target_wav: str,
+        target_wav: Union[str, list[str]],
         file_path: str = "output.wav",
         pipe_out=None,
     ) -> str:
@@ -429,9 +429,10 @@ def voice_conversion_to_file(
     def tts_with_vc(
         self,
         text: str,
-        language: str = None,
-        speaker_wav: str = None,
-        speaker: str = None,
+        *,
+        language: Optional[str] = None,
+        speaker_wav: Union[str, list[str]],
+        speaker: Optional[str] = None,
         split_sentences: bool = True,
     ):
         """Convert text to speech with voice conversion.
@@ -471,10 +472,11 @@ def tts_with_vc(
     def tts_with_vc_to_file(
         self,
         text: str,
-        language: str = None,
-        speaker_wav: str = None,
+        *,
+        language: Optional[str] = None,
+        speaker_wav: Union[str, list[str]],
         file_path: str = "output.wav",
-        speaker: str = None,
+        speaker: Optional[str] = None,
         split_sentences: bool = True,
         pipe_out=None,
     ) -> str:
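With this change, target_wav and speaker_wav accept either a single path or a list of paths, and the tts_with_vc methods take their remaining arguments keyword-only. A minimal usage sketch against the new API (the FreeVC model name and file paths are placeholders, assuming that checkpoint is available locally):

from TTS.api import TTS

tts = TTS("voice_conversion_models/multilingual/vctk/freevc24")
# Passing a list computes one speaker embedding per file and averages them.
tts.voice_conversion_to_file(
    source_wav="source.wav",
    target_wav=["target1.wav", "target2.wav"],
    file_path="converted.wav",
)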

TTS/bin/synthesize.py (+3 -2)

@@ -275,13 +275,14 @@ def parse_args(arg_list: Optional[list[str]]) -> argparse.Namespace:
         "--source_wav",
         type=str,
         default=None,
-        help="Original audio file to convert in the voice of the target_wav",
+        help="Original audio file to convert into the voice of the target_wav",
     )
     parser.add_argument(
         "--target_wav",
         type=str,
+        nargs="*",
         default=None,
-        help="Target audio file to convert in the voice of the source_wav",
+        help="Audio file(s) of the target voice into which to convert the source_wav",
     )
 
     parser.add_argument(
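Because --target_wav now uses nargs="*", the CLI consumes every path that follows the flag. A hypothetical invocation (model name and file names are placeholders):

tts --model_name voice_conversion_models/multilingual/vctk/freevc24 \
    --source_wav source.wav \
    --target_wav target1.wav target2.wav \
    --out_path converted.wav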

TTS/utils/synthesizer.py (+4 -1)

@@ -274,8 +274,11 @@ def save_wav(self, wav: List[int], path: str, pipe_out=None) -> None:
         wav = np.array(wav)
         save_wav(wav=wav, path=path, sample_rate=self.output_sample_rate, pipe_out=pipe_out)
 
-    def voice_conversion(self, source_wav: str, target_wav: str, **kwargs) -> List[int]:
+    def voice_conversion(self, source_wav: str, target_wav: Union[str, list[str]], **kwargs) -> List[int]:
         start_time = time.time()
+
+        if not isinstance(target_wav, list):
+            target_wav = [target_wav]
         output = self.vc_model.voice_conversion(source_wav, target_wav, **kwargs)
         if self.vocoder_model is not None:
             output = self.vocoder_model.inference(output)
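Normalizing target_wav to a list at the Synthesizer boundary means every downstream vc_model.voice_conversion can assume list input, which is presumably why the knnvc signature below narrows to list[PathOrTensor]. A standalone sketch of the idiom:

def as_list(x):
    # Wrap a bare value so callers can always iterate over targets.
    return x if isinstance(x, list) else [x]

assert as_list("a.wav") == ["a.wav"]
assert as_list(["a.wav", "b.wav"]) == ["a.wav", "b.wav"]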

TTS/vc/layers/freevc/speaker_encoder/speaker_encoder.py (+1 -1)

@@ -115,7 +115,7 @@ def compute_partial_slices(n_samples: int, rate, min_coverage):
 
         return wav_slices, mel_slices
 
-    def embed_utterance(self, wav: np.ndarray, return_partials=False, rate=1.3, min_coverage=0.75):
+    def embed_utterance(self, wav: np.ndarray, return_partials=False, rate=1.3, min_coverage=0.75) -> torch.Tensor:
         """
         Computes an embedding for a single utterance. The utterance is divided in partial
         utterances and an embedding is computed for each. The complete utterance embedding is the

TTS/vc/models/freevc.py (+29 -30)

@@ -1,5 +1,5 @@
 import logging
-from typing import Optional, Tuple
+from typing import Optional, Tuple, Union
 
 import librosa
 import numpy as np
@@ -386,7 +386,7 @@ def forward(
         return o, ids_slice, spec_mask, (z, z_p, m_p, logs_p, m_q, logs_q)
 
     @torch.inference_mode()
-    def inference(self, c, g=None, mel=None, c_lengths=None):
+    def inference(self, c, g=None, c_lengths=None):
         """
         Inference pass of the model
 
@@ -401,9 +401,6 @@ def inference(self, c, g=None, c_lengths=None):
         """
         if c_lengths is None:
             c_lengths = (torch.ones(c.size(0)) * c.size(-1)).to(c.device)
-        if not self.use_spk:
-            g = self.enc_spk.embed_utterance(mel)
-            g = g.unsqueeze(-1)
         z_p, m_p, logs_p, c_mask = self.enc_p(c, c_lengths)
         z = self.flow(z_p, c_mask, g=g, reverse=True)
         o = self.dec(z * c_mask, g=g)
@@ -434,45 +431,47 @@ def load_audio(self, wav):
         return wav.float()
 
     @torch.inference_mode()
-    def voice_conversion(self, src, tgt):
+    def voice_conversion(self, src: Union[str, torch.Tensor], tgt: list[Union[str, torch.Tensor]]):
         """
         Voice conversion pass of the model.
 
         Args:
             src (str or torch.Tensor): Source utterance.
-            tgt (str or torch.Tensor): Target utterance.
+            tgt (list of str or torch.Tensor): Target utterances.
 
         Returns:
             torch.Tensor: Output tensor.
         """
 
-        wav_tgt = self.load_audio(tgt).cpu().numpy()
-        wav_tgt, _ = librosa.effects.trim(wav_tgt, top_db=20)
-
-        if self.config.model_args.use_spk:
-            g_tgt = self.enc_spk_ex.embed_utterance(wav_tgt)[None, :, None]
-        else:
-            wav_tgt = torch.from_numpy(wav_tgt).unsqueeze(0).to(self.device)
-            mel_tgt = mel_spectrogram_torch(
-                wav_tgt,
-                self.config.audio.filter_length,
-                self.config.audio.n_mel_channels,
-                self.config.audio.input_sample_rate,
-                self.config.audio.hop_length,
-                self.config.audio.win_length,
-                self.config.audio.mel_fmin,
-                self.config.audio.mel_fmax,
-            )
         # src
         wav_src = self.load_audio(src)
         c = self.extract_wavlm_features(wav_src[None, :])
 
-        if self.config.model_args.use_spk:
-            audio = self.inference(c, g=g_tgt)
-        else:
-            audio = self.inference(c, mel=mel_tgt.transpose(1, 2))
-        audio = audio[0][0].data.cpu().float().numpy()
-        return audio
+        # tgt
+        g_tgts = []
+        for tg in tgt:
+            wav_tgt = self.load_audio(tg).cpu().numpy()
+            wav_tgt, _ = librosa.effects.trim(wav_tgt, top_db=20)
+
+            if self.config.model_args.use_spk:
+                g_tgts.append(self.enc_spk_ex.embed_utterance(wav_tgt)[None, :, None])
+            else:
+                wav_tgt = torch.from_numpy(wav_tgt).unsqueeze(0).to(self.device)
+                mel_tgt = mel_spectrogram_torch(
+                    wav_tgt,
+                    self.config.audio.filter_length,
+                    self.config.audio.n_mel_channels,
+                    self.config.audio.input_sample_rate,
+                    self.config.audio.hop_length,
+                    self.config.audio.win_length,
+                    self.config.audio.mel_fmin,
+                    self.config.audio.mel_fmax,
+                )
+                g_tgts.append(self.enc_spk.embed_utterance(mel_tgt.transpose(1, 2)).unsqueeze(-1))
+
+        g_tgt = torch.stack(g_tgts).mean(dim=0)
+        audio = self.inference(c, g=g_tgt)
+        return audio[0][0].data.cpu().float().numpy()
 
     def eval_step(): ...
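The multi-target logic boils down to mean-pooling one speaker embedding per reference utterance into a single conditioning vector. A minimal sketch of that pooling step, assuming FreeVC's speaker embedding is 256-dimensional (shapes are illustrative):

import torch

# One [1, 256, 1] embedding per target file, e.g. three references.
embeddings = [torch.randn(1, 256, 1) for _ in range(3)]
g_tgt = torch.stack(embeddings).mean(dim=0)  # shape stays [1, 256, 1]
assert g_tgt.shape == (1, 256, 1)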

TTS/vc/models/knnvc.py (+1 -1)

@@ -172,7 +172,7 @@ def inference(self) -> None: ...
     def voice_conversion(
         self,
         source: PathOrTensor,
-        target: Union[PathOrTensor, list[PathOrTensor]],
+        target: list[PathOrTensor],
         topk: Optional[int] = None,
     ) -> torch.Tensor:
         if not isinstance(target, list):

TTS/vc/models/openvoice.py (+9 -3)

@@ -296,19 +296,25 @@ def extract_se(self, audio: Union[str, torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
         return g, spec
 
     @torch.inference_mode()
-    def voice_conversion(self, src: Union[str, torch.Tensor], tgt: Union[str, torch.Tensor]) -> npt.NDArray[np.float32]:
+    def voice_conversion(
+        self, src: Union[str, torch.Tensor], tgt: list[Union[str, torch.Tensor]]
+    ) -> npt.NDArray[np.float32]:
         """
         Voice conversion pass of the model.
 
         Args:
             src (str or torch.Tensor): Source utterance.
-            tgt (str or torch.Tensor): Target utterance.
+            tgt (list of str or torch.Tensor): Target utterances.
 
         Returns:
             Output numpy array.
         """
         src_se, src_spec = self.extract_se(src)
-        tgt_se, _ = self.extract_se(tgt)
+        tgt_ses = []
+        for tg in tgt:
+            tgt_se, _ = self.extract_se(tg)
+            tgt_ses.append(tgt_se)
+        tgt_se = torch.stack(tgt_ses).mean(dim=0)
 
         aux_input = {"g_src": src_se, "g_tgt": tgt_se}
         audio = self.inference(src_spec, aux_input)
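OpenVoice follows the same strategy as FreeVC above: one embedding per reference via extract_se, then an element-wise mean. Since the mean is order-invariant, the same set of reference files yields the same conditioning regardless of argument order.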

tests/zoo_tests/test_models.py (+3 -2)

@@ -71,8 +71,9 @@ def test_models(tmp_path, model_name, manager):
         run_main(main, [*args, "--text", "This is an example.", *speaker_arg, *language_arg])
     elif "voice_conversion_models" in model_name:
         speaker_wav = os.path.join(get_tests_data_path(), "ljspeech", "wavs", "LJ001-0001.wav")
-        reference_wav = os.path.join(get_tests_data_path(), "ljspeech", "wavs", "LJ001-0032.wav")
-        run_main(main, [*args, "--source_wav", speaker_wav, "--target_wav", reference_wav])
+        reference_wav1 = os.path.join(get_tests_data_path(), "ljspeech", "wavs", "LJ001-0028.wav")
+        reference_wav2 = os.path.join(get_tests_data_path(), "ljspeech", "wavs", "LJ001-0032.wav")
+        run_main(main, [*args, "--source_wav", speaker_wav, "--target_wav", reference_wav1, reference_wav2])
     else:
         # only download the model
         manager.download_model(model_name)
