@@ -1,5 +1,5 @@
 import logging
-from typing import Optional, Tuple
+from typing import Optional, Tuple, Union

 import librosa
 import numpy as np
@@ -386,7 +386,7 @@ def forward(
         return o, ids_slice, spec_mask, (z, z_p, m_p, logs_p, m_q, logs_q)

     @torch.inference_mode()
-    def inference(self, c, g=None, mel=None, c_lengths=None):
+    def inference(self, c, g=None, c_lengths=None):
         """
         Inference pass of the model

@@ -401,9 +401,6 @@ def inference(self, c, g=None, mel=None, c_lengths=None):
         """
         if c_lengths is None:
             c_lengths = (torch.ones(c.size(0)) * c.size(-1)).to(c.device)
-        if not self.use_spk:
-            g = self.enc_spk.embed_utterance(mel)
-            g = g.unsqueeze(-1)
         z_p, m_p, logs_p, c_mask = self.enc_p(c, c_lengths)
         z = self.flow(z_p, c_mask, g=g, reverse=True)
         o = self.dec(z * c_mask, g=g)
@@ -434,45 +431,47 @@ def load_audio(self, wav):
         return wav.float()

     @torch.inference_mode()
-    def voice_conversion(self, src, tgt):
+    def voice_conversion(self, src: Union[str, torch.Tensor], tgt: list[Union[str, torch.Tensor]]):
         """
         Voice conversion pass of the model.

         Args:
             src (str or torch.Tensor): Source utterance.
-            tgt (str or torch.Tensor): Target utterance.
+            tgt (list of str or torch.Tensor): Target utterances.

         Returns:
             torch.Tensor: Output tensor.
         """

-        wav_tgt = self.load_audio(tgt).cpu().numpy()
-        wav_tgt, _ = librosa.effects.trim(wav_tgt, top_db=20)
-
-        if self.config.model_args.use_spk:
-            g_tgt = self.enc_spk_ex.embed_utterance(wav_tgt)[None, :, None]
-        else:
-            wav_tgt = torch.from_numpy(wav_tgt).unsqueeze(0).to(self.device)
-            mel_tgt = mel_spectrogram_torch(
-                wav_tgt,
-                self.config.audio.filter_length,
-                self.config.audio.n_mel_channels,
-                self.config.audio.input_sample_rate,
-                self.config.audio.hop_length,
-                self.config.audio.win_length,
-                self.config.audio.mel_fmin,
-                self.config.audio.mel_fmax,
-            )
         # src
         wav_src = self.load_audio(src)
         c = self.extract_wavlm_features(wav_src[None, :])

-        if self.config.model_args.use_spk:
-            audio = self.inference(c, g=g_tgt)
-        else:
-            audio = self.inference(c, mel=mel_tgt.transpose(1, 2))
-        audio = audio[0][0].data.cpu().float().numpy()
-        return audio
+        # tgt
+        g_tgts = []
+        for tg in tgt:
+            wav_tgt = self.load_audio(tg).cpu().numpy()
+            wav_tgt, _ = librosa.effects.trim(wav_tgt, top_db=20)
+
+            if self.config.model_args.use_spk:
+                g_tgts.append(self.enc_spk_ex.embed_utterance(wav_tgt)[None, :, None])
+            else:
+                wav_tgt = torch.from_numpy(wav_tgt).unsqueeze(0).to(self.device)
+                mel_tgt = mel_spectrogram_torch(
+                    wav_tgt,
+                    self.config.audio.filter_length,
+                    self.config.audio.n_mel_channels,
+                    self.config.audio.input_sample_rate,
+                    self.config.audio.hop_length,
+                    self.config.audio.win_length,
+                    self.config.audio.mel_fmin,
+                    self.config.audio.mel_fmax,
+                )
+                g_tgts.append(self.enc_spk.embed_utterance(mel_tgt.transpose(1, 2)).unsqueeze(-1))
+
+        g_tgt = torch.stack(g_tgts).mean(dim=0)
+        audio = self.inference(c, g=g_tgt)
+        return audio[0][0].data.cpu().float().numpy()

     def eval_step(): ...
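
With this change, `voice_conversion` accepts a list of target utterances and conditions the decoder on the mean of their speaker embeddings, so several reference clips of the same speaker (or a deliberate blend of speakers) can be used at once. A minimal usage sketch, assuming `model` is an already loaded FreeVC instance (loading code omitted), placeholder WAV paths, and an `output_sample_rate` attribute on the audio config (an assumption, not confirmed by this diff):

```python
import soundfile as sf

# `model` is assumed to be a loaded FreeVC instance; checkpoint loading is omitted here.
src_path = "source.wav"                       # placeholder path
tgt_paths = ["target_a.wav", "target_b.wav"]  # tgt is now a list of reference clips

# Each target is trimmed, embedded, and the per-target speaker embeddings are
# stacked and averaged into a single conditioning vector g before decoding.
converted = model.voice_conversion(src=src_path, tgt=tgt_paths)

# The method returns a NumPy float array; the sample-rate attribute name is assumed.
sf.write("converted.wav", converted, model.config.audio.output_sample_rate)
```

Averaging the reference embeddings trades per-clip detail for a more stable estimate of the target voice, which is the usual motivation for accepting multiple reference utterances.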