 from TTS.tts.layers.delightful_tts.encoders import (
     PhonemeLevelProsodyEncoder,
     UtteranceLevelProsodyEncoder,
-    get_mask_from_lengths,
 )
 from TTS.tts.layers.delightful_tts.energy_adaptor import EnergyAdaptor
 from TTS.tts.layers.delightful_tts.networks import EmbeddingPadded, positional_encoding
 from TTS.tts.layers.delightful_tts.phoneme_prosody_predictor import PhonemeProsodyPredictor
 from TTS.tts.layers.delightful_tts.pitch_adaptor import PitchAdaptor
 from TTS.tts.layers.delightful_tts.variance_predictor import VariancePredictor
 from TTS.tts.layers.generic.aligner import AlignmentNetwork
-from TTS.tts.utils.helpers import generate_path, sequence_mask
+from TTS.tts.utils.helpers import expand_encoder_outputs, generate_attention, sequence_mask


 logger = logging.getLogger(__name__)
@@ -231,42 +230,6 @@ def _init_d_vector(self):
             raise ValueError("[!] Speaker embedding layer already initialized before d_vector settings.")
         self.embedded_speaker_dim = self.args.d_vector_dim

-    @staticmethod
-    def generate_attn(dr, x_mask, y_mask=None):
-        """Generate an attention mask from the linear scale durations.
-
-        Args:
-            dr (Tensor): Linear scale durations.
-            x_mask (Tensor): Mask for the input (character) sequence.
-            y_mask (Tensor): Mask for the output (spectrogram) sequence. Compute it from the predicted durations
-                if None. Defaults to None.
-
-        Shapes
-            - dr: :math:`(B, T_{en})`
-            - x_mask: :math:`(B, T_{en})`
-            - y_mask: :math:`(B, T_{de})`
-        """
-        # compute decode mask from the durations
-        if y_mask is None:
-            y_lengths = dr.sum(1).long()
-            y_lengths[y_lengths < 1] = 1
-            y_mask = torch.unsqueeze(sequence_mask(y_lengths, None), 1).to(dr.dtype)
-        attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2)
-        attn = generate_path(dr, attn_mask.squeeze(1)).to(dr.dtype)
-        return attn
-
-    def _expand_encoder_with_durations(
-        self,
-        o_en: torch.FloatTensor,
-        dr: torch.IntTensor,
-        x_mask: torch.IntTensor,
-        y_lengths: torch.IntTensor,
-    ):
-        y_mask = torch.unsqueeze(sequence_mask(y_lengths, None), 1).to(o_en.dtype)
-        attn = self.generate_attn(dr, x_mask, y_mask)
-        o_en_ex = torch.einsum("kmn, kjm -> kjn", [attn.float(), o_en])
-        return y_mask, o_en_ex, attn.transpose(1, 2)
-
     def _forward_aligner(
         self,
         x: torch.FloatTensor,
@@ -340,8 +303,8 @@ def forward(
             {"d_vectors": d_vectors, "speaker_ids": speaker_idx}
         )  # pylint: disable=unused-variable

-        src_mask = get_mask_from_lengths(src_lens)  # [B, T_src]
-        mel_mask = get_mask_from_lengths(mel_lens)  # [B, T_mel]
+        src_mask = ~sequence_mask(src_lens)  # [B, T_src]
+        mel_mask = ~sequence_mask(mel_lens)  # [B, T_mel]

         # Token embeddings
         token_embeddings = self.src_word_emb(tokens)  # [B, T_src, C_hidden]
@@ -420,8 +383,8 @@ def forward(
         encoder_outputs = encoder_outputs.transpose(1, 2) + pitch_emb + energy_emb
         log_duration_prediction = self.duration_predictor(x=encoder_outputs_res.detach(), mask=src_mask)

-        mel_pred_mask, encoder_outputs_ex, alignments = self._expand_encoder_with_durations(
-            o_en=encoder_outputs, y_lengths=mel_lens, dr=dr, x_mask=~src_mask[:, None]
+        encoder_outputs_ex, alignments, mel_pred_mask = expand_encoder_outputs(
+            encoder_outputs, y_lengths=mel_lens, duration=dr, x_mask=~src_mask[:, None]
         )

         x = self.decoder(
@@ -435,7 +398,7 @@ def forward(
         dr = torch.log(dr + 1)

         dr_pred = torch.exp(log_duration_prediction) - 1
-        alignments_dp = self.generate_attn(dr_pred, src_mask.unsqueeze(1), mel_pred_mask)  # [B, T_max, T_max2']
+        alignments_dp = generate_attention(dr_pred, src_mask.unsqueeze(1), mel_pred_mask)  # [B, T_max, T_max2']

         return {
             "model_outputs": x,
@@ -448,7 +411,7 @@ def forward(
             "p_prosody_pred": p_prosody_pred,
             "p_prosody_ref": p_prosody_ref,
             "alignments_dp": alignments_dp,
-            "alignments": alignments,  # [B, T_de, T_en]
+            "alignments": alignments.transpose(1, 2),  # [B, T_de, T_en]
             "aligner_soft": aligner_soft,
             "aligner_mas": aligner_mas,
             "aligner_durations": aligner_durations,
@@ -469,7 +432,7 @@ def inference(
         pitch_transform: Callable = None,
         energy_transform: Callable = None,
     ) -> torch.Tensor:
-        src_mask = get_mask_from_lengths(torch.tensor([tokens.shape[1]], dtype=torch.int64, device=tokens.device))
+        src_mask = ~sequence_mask(torch.tensor([tokens.shape[1]], dtype=torch.int64, device=tokens.device))
         src_lens = torch.tensor(tokens.shape[1:2]).to(tokens.device)  # pylint: disable=unused-variable
         sid, g, lid, _ = self._set_cond_input(  # pylint: disable=unused-variable
             {"d_vectors": d_vectors, "speaker_ids": speaker_idx}
@@ -536,11 +499,11 @@ def inference(
         duration_pred = torch.round(duration_pred)  # -> [B, T_src]
         mel_lens = duration_pred.sum(1)  # -> [B,]

-        _, encoder_outputs_ex, alignments = self._expand_encoder_with_durations(
-            o_en=encoder_outputs, y_lengths=mel_lens, dr=duration_pred.squeeze(1), x_mask=~src_mask[:, None]
+        encoder_outputs_ex, alignments, _ = expand_encoder_outputs(
+            encoder_outputs, y_lengths=mel_lens, duration=duration_pred.squeeze(1), x_mask=~src_mask[:, None]
         )

-        mel_mask = get_mask_from_lengths(
+        mel_mask = ~sequence_mask(
             torch.tensor([encoder_outputs_ex.shape[2]], dtype=torch.int64, device=encoder_outputs_ex.device)
         )

@@ -557,7 +520,7 @@ def inference(
         x = self.to_mel(x)
         outputs = {
             "model_outputs": x,
-            "alignments": alignments,
+            "alignments": alignments.transpose(1, 2),
             # "pitch": pitch_emb_pred,
             "durations": duration_pred,
             "pitch": pitch_pred,
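For context, here is a minimal sketch of what the shared helpers used above are assumed to do, reconstructed from the removed `generate_attn` and `_expand_encoder_with_durations` methods. Only the call patterns visible in this diff are taken as given; the internals and any parameter names not shown here (e.g. the first positional argument `x`) are assumptions, and the actual implementations live in `TTS.tts.utils.helpers`.

# Sketch only, not the library source: reconstructed from the private methods removed above.
import torch
from TTS.tts.utils.helpers import generate_path, sequence_mask


def generate_attention(duration, x_mask, y_mask=None):
    # Turn linear-scale durations into a hard alignment map (same logic as the removed generate_attn).
    if y_mask is None:
        y_lengths = duration.sum(1).long()
        y_lengths[y_lengths < 1] = 1
        y_mask = torch.unsqueeze(sequence_mask(y_lengths, None), 1).to(duration.dtype)
    attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2)
    return generate_path(duration, attn_mask.squeeze(1)).to(duration.dtype)


def expand_encoder_outputs(x, duration, x_mask, y_lengths):
    # Repeat encoder frames according to durations, as _expand_encoder_with_durations did,
    # but returning (expanded, attention, y_mask) in the order the new call sites unpack them.
    y_mask = torch.unsqueeze(sequence_mask(y_lengths, None), 1).to(x.dtype)
    attn = generate_attention(duration, x_mask, y_mask)
    x_expanded = torch.einsum("kmn, kjm -> kjn", [attn.float(), x])
    return x_expanded, attn, y_mask

Note that the attention returned this way is [B, T_en, T_de], which is why the new call sites apply `.transpose(1, 2)` before exposing "alignments" as [B, T_de, T_en] in the output dicts.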