Skip to content

Commit 950a9a2

Browse files
committed
refactor: move duplicate alignment functions into helpers
1 parent 90087b5 commit 950a9a2

File tree

7 files changed

+114
-170
lines changed

7 files changed

+114
-170
lines changed

TTS/tts/layers/delightful_tts/acoustic_model.py

+12-49
Original file line number | Diff line number | Diff line change
@@ -12,15 +12,14 @@
1212
from TTS.tts.layers.delightful_tts.encoders import (
1313
PhonemeLevelProsodyEncoder,
1414
UtteranceLevelProsodyEncoder,
15-
get_mask_from_lengths,
1615
)
1716
from TTS.tts.layers.delightful_tts.energy_adaptor import EnergyAdaptor
1817
from TTS.tts.layers.delightful_tts.networks import EmbeddingPadded, positional_encoding
1918
from TTS.tts.layers.delightful_tts.phoneme_prosody_predictor import PhonemeProsodyPredictor
2019
from TTS.tts.layers.delightful_tts.pitch_adaptor import PitchAdaptor
2120
from TTS.tts.layers.delightful_tts.variance_predictor import VariancePredictor
2221
from TTS.tts.layers.generic.aligner import AlignmentNetwork
23-
from TTS.tts.utils.helpers import generate_path, sequence_mask
22+
from TTS.tts.utils.helpers import expand_encoder_outputs, generate_attention, sequence_mask
2423

2524
logger = logging.getLogger(__name__)
2625

@@ -231,42 +230,6 @@ def _init_d_vector(self):
231230
raise ValueError("[!] Speaker embedding layer already initialized before d_vector settings.")
232231
self.embedded_speaker_dim = self.args.d_vector_dim
233232

234-
@staticmethod
235-
def generate_attn(dr, x_mask, y_mask=None):
236-
"""Generate an attention mask from the linear scale durations.
237-
238-
Args:
239-
dr (Tensor): Linear scale durations.
240-
x_mask (Tensor): Mask for the input (character) sequence.
241-
y_mask (Tensor): Mask for the output (spectrogram) sequence. Compute it from the predicted durations
242-
if None. Defaults to None.
243-
244-
Shapes
245-
- dr: :math:`(B, T_{en})`
246-
- x_mask: :math:`(B, T_{en})`
247-
- y_mask: :math:`(B, T_{de})`
248-
"""
249-
# compute decode mask from the durations
250-
if y_mask is None:
251-
y_lengths = dr.sum(1).long()
252-
y_lengths[y_lengths < 1] = 1
253-
y_mask = torch.unsqueeze(sequence_mask(y_lengths, None), 1).to(dr.dtype)
254-
attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2)
255-
attn = generate_path(dr, attn_mask.squeeze(1)).to(dr.dtype)
256-
return attn
257-
258-
def _expand_encoder_with_durations(
259-
self,
260-
o_en: torch.FloatTensor,
261-
dr: torch.IntTensor,
262-
x_mask: torch.IntTensor,
263-
y_lengths: torch.IntTensor,
264-
):
265-
y_mask = torch.unsqueeze(sequence_mask(y_lengths, None), 1).to(o_en.dtype)
266-
attn = self.generate_attn(dr, x_mask, y_mask)
267-
o_en_ex = torch.einsum("kmn, kjm -> kjn", [attn.float(), o_en])
268-
return y_mask, o_en_ex, attn.transpose(1, 2)
269-
270233
def _forward_aligner(
271234
self,
272235
x: torch.FloatTensor,
@@ -340,8 +303,8 @@ def forward(
340303
{"d_vectors": d_vectors, "speaker_ids": speaker_idx}
341304
) # pylint: disable=unused-variable
342305

343-
src_mask = get_mask_from_lengths(src_lens) # [B, T_src]
344-
mel_mask = get_mask_from_lengths(mel_lens) # [B, T_mel]
306+
src_mask = ~sequence_mask(src_lens) # [B, T_src]
307+
mel_mask = ~sequence_mask(mel_lens) # [B, T_mel]
345308

346309
# Token embeddings
347310
token_embeddings = self.src_word_emb(tokens) # [B, T_src, C_hidden]
@@ -420,8 +383,8 @@ def forward(
420383
encoder_outputs = encoder_outputs.transpose(1, 2) + pitch_emb + energy_emb
421384
log_duration_prediction = self.duration_predictor(x=encoder_outputs_res.detach(), mask=src_mask)
422385

423-
mel_pred_mask, encoder_outputs_ex, alignments = self._expand_encoder_with_durations(
424-
o_en=encoder_outputs, y_lengths=mel_lens, dr=dr, x_mask=~src_mask[:, None]
386+
encoder_outputs_ex, alignments, mel_pred_mask = expand_encoder_outputs(
387+
encoder_outputs, y_lengths=mel_lens, duration=dr, x_mask=~src_mask[:, None]
425388
)
426389

427390
x = self.decoder(
@@ -435,7 +398,7 @@ def forward(
435398
dr = torch.log(dr + 1)
436399

437400
dr_pred = torch.exp(log_duration_prediction) - 1
438-
alignments_dp = self.generate_attn(dr_pred, src_mask.unsqueeze(1), mel_pred_mask) # [B, T_max, T_max2']
401+
alignments_dp = generate_attention(dr_pred, src_mask.unsqueeze(1), mel_pred_mask) # [B, T_max, T_max2']
439402

440403
return {
441404
"model_outputs": x,
@@ -448,7 +411,7 @@ def forward(
448411
"p_prosody_pred": p_prosody_pred,
449412
"p_prosody_ref": p_prosody_ref,
450413
"alignments_dp": alignments_dp,
451-
"alignments": alignments, # [B, T_de, T_en]
414+
"alignments": alignments.transpose(1, 2), # [B, T_de, T_en]
452415
"aligner_soft": aligner_soft,
453416
"aligner_mas": aligner_mas,
454417
"aligner_durations": aligner_durations,
@@ -469,7 +432,7 @@ def inference(
469432
pitch_transform: Callable = None,
470433
energy_transform: Callable = None,
471434
) -> torch.Tensor:
472-
src_mask = get_mask_from_lengths(torch.tensor([tokens.shape[1]], dtype=torch.int64, device=tokens.device))
435+
src_mask = ~sequence_mask(torch.tensor([tokens.shape[1]], dtype=torch.int64, device=tokens.device))
473436
src_lens = torch.tensor(tokens.shape[1:2]).to(tokens.device) # pylint: disable=unused-variable
474437
sid, g, lid, _ = self._set_cond_input( # pylint: disable=unused-variable
475438
{"d_vectors": d_vectors, "speaker_ids": speaker_idx}
@@ -536,11 +499,11 @@ def inference(
536499
duration_pred = torch.round(duration_pred) # -> [B, T_src]
537500
mel_lens = duration_pred.sum(1) # -> [B,]
538501

539-
_, encoder_outputs_ex, alignments = self._expand_encoder_with_durations(
540-
o_en=encoder_outputs, y_lengths=mel_lens, dr=duration_pred.squeeze(1), x_mask=~src_mask[:, None]
502+
encoder_outputs_ex, alignments, _ = expand_encoder_outputs(
503+
encoder_outputs, y_lengths=mel_lens, duration=duration_pred.squeeze(1), x_mask=~src_mask[:, None]
541504
)
542505

543-
mel_mask = get_mask_from_lengths(
506+
mel_mask = ~sequence_mask(
544507
torch.tensor([encoder_outputs_ex.shape[2]], dtype=torch.int64, device=encoder_outputs_ex.device)
545508
)
546509

@@ -557,7 +520,7 @@ def inference(
557520
x = self.to_mel(x)
558521
outputs = {
559522
"model_outputs": x,
560-
"alignments": alignments,
523+
"alignments": alignments.transpose(1, 2),
561524
# "pitch": pitch_emb_pred,
562525
"durations": duration_pred,
563526
"pitch": pitch_pred,

TTS/tts/layers/delightful_tts/encoders.py

+3-10
Original file line number | Diff line number | Diff line change
@@ -7,14 +7,7 @@
77
from TTS.tts.layers.delightful_tts.conformer import ConformerMultiHeadedSelfAttention
88
from TTS.tts.layers.delightful_tts.conv_layers import CoordConv1d
99
from TTS.tts.layers.delightful_tts.networks import STL
10-
11-
12-
def get_mask_from_lengths(lengths: torch.Tensor) -> torch.Tensor:
13-
batch_size = lengths.shape[0]
14-
max_len = torch.max(lengths).item()
15-
ids = torch.arange(0, max_len, device=lengths.device).unsqueeze(0).expand(batch_size, -1)
16-
mask = ids >= lengths.unsqueeze(1).expand(-1, max_len)
17-
return mask
10+
from TTS.tts.utils.helpers import sequence_mask
1811

1912

2013
def stride_lens(lens: torch.Tensor, stride: int = 2) -> torch.Tensor:
@@ -93,7 +86,7 @@ def forward(self, x: torch.Tensor, mel_lens: torch.Tensor) -> Tuple[torch.Tensor
9386
outputs --- [N, E//2]
9487
"""
9588

96-
mel_masks = get_mask_from_lengths(mel_lens).unsqueeze(1)
89+
mel_masks = ~sequence_mask(mel_lens).unsqueeze(1)
9790
x = x.masked_fill(mel_masks, 0)
9891
for conv, norm in zip(self.convs, self.norms):
9992
x = conv(x)
@@ -103,7 +96,7 @@ def forward(self, x: torch.Tensor, mel_lens: torch.Tensor) -> Tuple[torch.Tensor
10396
for _ in range(2):
10497
mel_lens = stride_lens(mel_lens)
10598

106-
mel_masks = get_mask_from_lengths(mel_lens)
99+
mel_masks = ~sequence_mask(mel_lens)
107100

108101
x = x.masked_fill(mel_masks.unsqueeze(1), 0)
109102
x = x.permute((0, 2, 1))

TTS/tts/models/align_tts.py

+3-33
Original file line number | Diff line number | Diff line change
@@ -13,7 +13,7 @@
1313
from TTS.tts.layers.feed_forward.encoder import Encoder
1414
from TTS.tts.layers.generic.pos_encoding import PositionalEncoding
1515
from TTS.tts.models.base_tts import BaseTTS
16-
from TTS.tts.utils.helpers import generate_path, sequence_mask
16+
from TTS.tts.utils.helpers import expand_encoder_outputs, generate_attention, sequence_mask
1717
from TTS.tts.utils.speakers import SpeakerManager
1818
from TTS.tts.utils.text.tokenizer import TTSTokenizer
1919
from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
@@ -169,35 +169,6 @@ def compute_align_path(self, mu, log_sigma, y, x_mask, y_mask):
169169
dr_mas = torch.sum(attn, -1)
170170
return dr_mas.squeeze(1), log_p
171171

172-
@staticmethod
173-
def generate_attn(dr, x_mask, y_mask=None):
174-
# compute decode mask from the durations
175-
if y_mask is None:
176-
y_lengths = dr.sum(1).long()
177-
y_lengths[y_lengths < 1] = 1
178-
y_mask = torch.unsqueeze(sequence_mask(y_lengths, None), 1).to(dr.dtype)
179-
attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2)
180-
attn = generate_path(dr, attn_mask.squeeze(1)).to(dr.dtype)
181-
return attn
182-
183-
def expand_encoder_outputs(self, en, dr, x_mask, y_mask):
184-
"""Generate attention alignment map from durations and
185-
expand encoder outputs
186-
187-
Examples::
188-
- encoder output: [a,b,c,d]
189-
- durations: [1, 3, 2, 1]
190-
191-
- expanded: [a, b, b, b, c, c, d]
192-
- attention map: [[0, 0, 0, 0, 0, 0, 1],
193-
[0, 0, 0, 0, 1, 1, 0],
194-
[0, 1, 1, 1, 0, 0, 0],
195-
[1, 0, 0, 0, 0, 0, 0]]
196-
"""
197-
attn = self.generate_attn(dr, x_mask, y_mask)
198-
o_en_ex = torch.matmul(attn.squeeze(1).transpose(1, 2), en.transpose(1, 2)).transpose(1, 2)
199-
return o_en_ex, attn
200-
201172
def format_durations(self, o_dr_log, x_mask):
202173
o_dr = (torch.exp(o_dr_log) - 1) * x_mask * self.length_scale
203174
o_dr[o_dr < 1] = 1.0
@@ -243,9 +214,8 @@ def _forward_encoder(self, x, x_lengths, g=None):
243214
return o_en, o_en_dp, x_mask, g
244215

245216
def _forward_decoder(self, o_en, o_en_dp, dr, x_mask, y_lengths, g):
246-
y_mask = torch.unsqueeze(sequence_mask(y_lengths, None), 1).to(o_en_dp.dtype)
247217
# expand o_en with durations
248-
o_en_ex, attn = self.expand_encoder_outputs(o_en, dr, x_mask, y_mask)
218+
o_en_ex, attn, y_mask = expand_encoder_outputs(o_en, dr, x_mask, y_lengths)
249219
# positional encoding
250220
if hasattr(self, "pos_encoder"):
251221
o_en_ex = self.pos_encoder(o_en_ex, y_mask)
@@ -282,7 +252,7 @@ def forward(
282252
o_en, o_en_dp, x_mask, g = self._forward_encoder(x, x_lengths, g)
283253
dr_mas, mu, log_sigma, logp = self._forward_mdn(o_en, y, y_lengths, x_mask)
284254
y_mask = torch.unsqueeze(sequence_mask(y_lengths, None), 1).to(o_en_dp.dtype)
285-
attn = self.generate_attn(dr_mas, x_mask, y_mask)
255+
attn = generate_attention(dr_mas, x_mask, y_mask)
286256
elif phase == 1:
287257
# train decoder
288258
o_en, o_en_dp, x_mask, g = self._forward_encoder(x, x_lengths, g)

TTS/tts/models/forward_tts.py

+3-47
Original file line number | Diff line number | Diff line change
@@ -14,7 +14,7 @@
1414
from TTS.tts.layers.generic.pos_encoding import PositionalEncoding
1515
from TTS.tts.layers.glow_tts.duration_predictor import DurationPredictor
1616
from TTS.tts.models.base_tts import BaseTTS
17-
from TTS.tts.utils.helpers import average_over_durations, generate_path, sequence_mask
17+
from TTS.tts.utils.helpers import average_over_durations, expand_encoder_outputs, generate_attention, sequence_mask
1818
from TTS.tts.utils.speakers import SpeakerManager
1919
from TTS.tts.utils.text.tokenizer import TTSTokenizer
2020
from TTS.tts.utils.visual import plot_alignment, plot_avg_energy, plot_avg_pitch, plot_spectrogram
@@ -310,49 +310,6 @@ def init_multispeaker(self, config: Coqpit):
310310
self.emb_g = nn.Embedding(self.num_speakers, self.args.hidden_channels)
311311
nn.init.uniform_(self.emb_g.weight, -0.1, 0.1)
312312

313-
@staticmethod
314-
def generate_attn(dr, x_mask, y_mask=None):
315-
"""Generate an attention mask from the durations.
316-
317-
Shapes
318-
- dr: :math:`(B, T_{en})`
319-
- x_mask: :math:`(B, T_{en})`
320-
- y_mask: :math:`(B, T_{de})`
321-
"""
322-
# compute decode mask from the durations
323-
if y_mask is None:
324-
y_lengths = dr.sum(1).long()
325-
y_lengths[y_lengths < 1] = 1
326-
y_mask = torch.unsqueeze(sequence_mask(y_lengths, None), 1).to(dr.dtype)
327-
attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2)
328-
attn = generate_path(dr, attn_mask.squeeze(1)).to(dr.dtype)
329-
return attn
330-
331-
def expand_encoder_outputs(self, en, dr, x_mask, y_mask):
332-
"""Generate attention alignment map from durations and
333-
expand encoder outputs
334-
335-
Shapes:
336-
- en: :math:`(B, D_{en}, T_{en})`
337-
- dr: :math:`(B, T_{en})`
338-
- x_mask: :math:`(B, T_{en})`
339-
- y_mask: :math:`(B, T_{de})`
340-
341-
Examples::
342-
343-
encoder output: [a,b,c,d]
344-
durations: [1, 3, 2, 1]
345-
346-
expanded: [a, b, b, b, c, c, d]
347-
attention map: [[0, 0, 0, 0, 0, 0, 1],
348-
[0, 0, 0, 0, 1, 1, 0],
349-
[0, 1, 1, 1, 0, 0, 0],
350-
[1, 0, 0, 0, 0, 0, 0]]
351-
"""
352-
attn = self.generate_attn(dr, x_mask, y_mask)
353-
o_en_ex = torch.matmul(attn.squeeze(1).transpose(1, 2).to(en.dtype), en.transpose(1, 2)).transpose(1, 2)
354-
return o_en_ex, attn
355-
356313
def format_durations(self, o_dr_log, x_mask):
357314
"""Format predicted durations.
358315
1. Convert to linear scale from log scale
@@ -443,9 +400,8 @@ def _forward_decoder(
443400
Returns:
444401
Tuple[torch.FloatTensor, torch.FloatTensor]: Decoder output, attention map from durations.
445402
"""
446-
y_mask = torch.unsqueeze(sequence_mask(y_lengths, None), 1).to(o_en.dtype)
447403
# expand o_en with durations
448-
o_en_ex, attn = self.expand_encoder_outputs(o_en, dr, x_mask, y_mask)
404+
o_en_ex, attn, y_mask = expand_encoder_outputs(o_en, dr, x_mask, y_lengths)
449405
# positional encoding
450406
if hasattr(self, "pos_encoder"):
451407
o_en_ex = self.pos_encoder(o_en_ex, y_mask)
@@ -624,7 +580,7 @@ def forward(
624580
o_dr_log = self.duration_predictor(o_en, x_mask)
625581
o_dr = torch.clamp(torch.exp(o_dr_log) - 1, 0, self.max_duration)
626582
# generate attn mask from predicted durations
627-
o_attn = self.generate_attn(o_dr.squeeze(1), x_mask)
583+
o_attn = generate_attention(o_dr.squeeze(1), x_mask)
628584
# aligner
629585
o_alignment_dur = None
630586
alignment_soft = None

0 commit comments

Comments (0)