
Commit fd1a951

committed Jan 26, 2025
add flow unified training
1 parent aea7520 commit fd1a951

File tree

4 files changed (+38 -26 lines)

 

cosyvoice/flow/decoder.py

(+19 -9)

@@ -210,6 +210,7 @@ def __init__(
         upcast_softmax: bool = False,
         cross_attention_norm: Optional[str] = None,
         cross_attention_norm_num_groups: int = 32,
+        qk_norm: Optional[str] = None,
         added_kv_proj_dim: Optional[int] = None,
         norm_num_groups: Optional[int] = None,
         spatial_norm_dim: Optional[int] = None,
@@ -223,7 +224,7 @@ def __init__(
         processor: Optional["AttnProcessor2_0"] = None,
         out_dim: int = None,
     ):
-        super(CausalAttention, self).__init__(query_dim, cross_attention_dim, heads, dim_head, dropout, bias, upcast_attention, upcast_softmax, cross_attention_norm, cross_attention_norm_num_groups,
+        super(CausalAttention, self).__init__(query_dim, cross_attention_dim, heads, dim_head, dropout, bias, upcast_attention, upcast_softmax, cross_attention_norm, cross_attention_norm_num_groups, qk_norm,
                                               added_kv_proj_dim, norm_num_groups, spatial_norm_dim, out_bias, scale_qk, only_cross_attention, eps, rescale_output_factor, residual_connection, _from_deprecated_attn_block, processor, out_dim)
         processor = CausalAttnProcessor2_0()
         self.set_processor(processor)
@@ -505,7 +506,7 @@ def initialize_weights(self):
                 if m.bias is not None:
                     nn.init.constant_(m.bias, 0)
 
-    def forward(self, x, mask, mu, t, spks=None, cond=None):
+    def forward(self, x, mask, mu, t, spks=None, cond=None, streaming=False):
         """Forward pass of the UNet1DConditional model.
 
         Args:
@@ -540,7 +541,7 @@ def forward(self, x, mask, mu, t, spks=None, cond=None):
             mask_down = masks[-1]
             x = resnet(x, mask_down, t)
             x = rearrange(x, "b c t -> b t c").contiguous()
-            attn_mask = (torch.matmul(mask_down.transpose(1, 2).contiguous(), mask_down) == 1)
+            attn_mask = add_optional_chunk_mask(x, mask_down.bool(), False, False, 0, 0, -1).repeat(1, x.size(1), 1)
             attn_mask = mask_to_bias(attn_mask, x.dtype)
             for transformer_block in transformer_blocks:
                 x = transformer_block(
@@ -558,7 +559,7 @@ def forward(self, x, mask, mu, t, spks=None, cond=None):
         for resnet, transformer_blocks in self.mid_blocks:
             x = resnet(x, mask_mid, t)
             x = rearrange(x, "b c t -> b t c").contiguous()
-            attn_mask = (torch.matmul(mask_mid.transpose(1, 2).contiguous(), mask_mid) == 1)
+            attn_mask = add_optional_chunk_mask(x, mask_mid.bool(), False, False, 0, 0, -1).repeat(1, x.size(1), 1)
             attn_mask = mask_to_bias(attn_mask, x.dtype)
             for transformer_block in transformer_blocks:
                 x = transformer_block(
@@ -574,7 +575,7 @@ def forward(self, x, mask, mu, t, spks=None, cond=None):
             x = pack([x[:, :, :skip.shape[-1]], skip], "b * t")[0]
             x = resnet(x, mask_up, t)
             x = rearrange(x, "b c t -> b t c").contiguous()
-            attn_mask = (torch.matmul(mask_up.transpose(1, 2).contiguous(), mask_up) == 1)
+            attn_mask = add_optional_chunk_mask(x, mask_up.bool(), False, False, 0, 0, -1).repeat(1, x.size(1), 1)
             attn_mask = mask_to_bias(attn_mask, x.dtype)
             for transformer_block in transformer_blocks:
                 x = transformer_block(
@@ -700,7 +701,7 @@ def __init__(
         self.final_proj = nn.Conv1d(channels[-1], self.out_channels, 1)
         self.initialize_weights()
 
-    def forward(self, x, mask, mu, t, spks=None, cond=None):
+    def forward(self, x, mask, mu, t, spks=None, cond=None, streaming=False):
         """Forward pass of the UNet1DConditional model.
 
         Args:
@@ -735,7 +736,10 @@ def forward(self, x, mask, mu, t, spks=None, cond=None):
             mask_down = masks[-1]
             x, _, _ = resnet(x, mask_down, t)
             x = rearrange(x, "b c t -> b t c").contiguous()
-            attn_mask = add_optional_chunk_mask(x, mask_down.bool(), False, False, 0, self.static_chunk_size, self.num_decoding_left_chunks)
+            if streaming is True:
+                attn_mask = add_optional_chunk_mask(x, mask_down.bool(), False, False, 0, self.static_chunk_size, self.num_decoding_left_chunks)
+            else:
+                attn_mask = add_optional_chunk_mask(x, mask_down.bool(), False, False, 0, 0, -1).repeat(1, x.size(1), 1)
             attn_mask = mask_to_bias(attn_mask, x.dtype)
             for transformer_block in transformer_blocks:
                 x, _ = transformer_block(
@@ -753,7 +757,10 @@ def forward(self, x, mask, mu, t, spks=None, cond=None):
         for resnet, transformer_blocks in self.mid_blocks:
             x, _, _ = resnet(x, mask_mid, t)
             x = rearrange(x, "b c t -> b t c").contiguous()
-            attn_mask = add_optional_chunk_mask(x, mask_mid.bool(), False, False, 0, self.static_chunk_size, self.num_decoding_left_chunks)
+            if streaming is True:
+                attn_mask = add_optional_chunk_mask(x, mask_mid.bool(), False, False, 0, self.static_chunk_size, self.num_decoding_left_chunks)
+            else:
+                attn_mask = add_optional_chunk_mask(x, mask_mid.bool(), False, False, 0, 0, -1).repeat(1, x.size(1), 1)
             attn_mask = mask_to_bias(attn_mask, x.dtype)
             for transformer_block in transformer_blocks:
                 x, _ = transformer_block(
@@ -769,7 +776,10 @@ def forward(self, x, mask, mu, t, spks=None, cond=None):
             x = pack([x[:, :, :skip.shape[-1]], skip], "b * t")[0]
             x, _, _ = resnet(x, mask_up, t)
             x = rearrange(x, "b c t -> b t c").contiguous()
-            attn_mask = add_optional_chunk_mask(x, mask_up.bool(), False, False, 0, self.static_chunk_size, self.num_decoding_left_chunks)
+            if streaming is True:
+                attn_mask = add_optional_chunk_mask(x, mask_up.bool(), False, False, 0, self.static_chunk_size, self.num_decoding_left_chunks)
+            else:
+                attn_mask = add_optional_chunk_mask(x, mask_up.bool(), False, False, 0, 0, -1).repeat(1, x.size(1), 1)
             attn_mask = mask_to_bias(attn_mask, x.dtype)
             for transformer_block in transformer_blocks:
                 x, _ = transformer_block(
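
Note on the decoder change: with streaming=False the forward pass now builds a full-context attention mask (static_chunk_size=0, num_decoding_left_chunks=-1), while streaming=True keeps the chunked mask driven by self.static_chunk_size and self.num_decoding_left_chunks. Below is a minimal, self-contained sketch of the two mask shapes; it is not the repo's add_optional_chunk_mask, and the helper name toy_chunk_mask is made up for illustration.

# Hypothetical, self-contained sketch -- NOT the repo's add_optional_chunk_mask --
# showing the two mask shapes the decoder now switches between.
import torch

def toy_chunk_mask(T: int, static_chunk_size: int, num_left_chunks: int) -> torch.Tensor:
    """(T, T) boolean mask: each frame sees its own chunk plus `num_left_chunks`
    previous chunks; num_left_chunks = -1 means all previous chunks."""
    if static_chunk_size <= 0:
        # streaming=False case: no chunking, every frame attends everywhere
        return torch.ones(T, T, dtype=torch.bool)
    mask = torch.zeros(T, T, dtype=torch.bool)
    for i in range(T):
        chunk = i // static_chunk_size
        start = 0 if num_left_chunks < 0 else max(0, (chunk - num_left_chunks) * static_chunk_size)
        end = min((chunk + 1) * static_chunk_size, T)
        mask[i, start:end] = True
    return mask

print(toy_chunk_mask(6, 0, -1).int())  # full-context mask (streaming=False branch)
print(toy_chunk_mask(6, 2, 1).int())   # chunked mask (streaming=True branch)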

cosyvoice/flow/flow.py

(+6 -2)

@@ -202,6 +202,9 @@ def forward(
         feat_len = batch['speech_feat_len'].to(device)
         embedding = batch['embedding'].to(device)
 
+        # NOTE unified training, static_chunk_size > 0 or = 0
+        streaming = True if random.random() < 0.5 else False
+
         # xvec projection
         embedding = F.normalize(embedding, dim=1)
         embedding = self.spk_embed_affine_layer(embedding)
@@ -211,7 +214,7 @@ def forward(
         token = self.input_embedding(torch.clamp(token, min=0)) * mask
 
         # text encode
-        h, h_lengths = self.encoder(token, token_len)
+        h, h_lengths = self.encoder(token, token_len, streaming=streaming)
         h = self.encoder_proj(h)
 
         # get conditions
@@ -230,7 +233,8 @@ def forward(
             mask.unsqueeze(1),
             h.transpose(1, 2).contiguous(),
             embedding,
-            cond=conds
+            cond=conds,
+            streaming=streaming,
         )
         return {'loss': loss}
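
Note on the flow.py change: each training batch flips a fair coin and trains either the streaming (chunked attention) path or the offline (full-context) path of the same model, which is what "unified training" in the commit message refers to. A hedged sketch of that sampling step follows; sample_streaming_flag and p_streaming are illustrative names, the diff simply inlines the expression.

# Hedged sketch of the per-batch coin flip added above; the function name and
# `p_streaming` are illustrative, not code from the repo.
import random

def sample_streaming_flag(p_streaming: float = 0.5) -> bool:
    # Equivalent to `True if random.random() < 0.5 else False`: about half of
    # the batches train the chunked (streaming) path, half the offline path.
    return random.random() < p_streaming

The same flag is threaded through both self.encoder(...) and the decoder loss, so the text encoder and the flow decoder always apply consistent masking within a batch.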

cosyvoice/flow/flow_matching.py

(+2 -5)

@@ -142,7 +142,7 @@ def forward_estimator(self, x, mask, mu, t, spks, cond):
                                        x.data_ptr()])
         return x
 
-    def compute_loss(self, x1, mask, mu, spks=None, cond=None):
+    def compute_loss(self, x1, mask, mu, spks=None, cond=None, streaming=False):
         """Computes diffusion loss
 
         Args:
@@ -179,11 +179,8 @@ def compute_loss(self, x1, mask, mu, spks=None, cond=None):
             spks = spks * cfg_mask.view(-1, 1)
             cond = cond * cfg_mask.view(-1, 1, 1)
 
-        pred = self.estimator(y, mask, mu, t.squeeze(), spks, cond)
+        pred = self.estimator(y, mask, mu, t.squeeze(), spks, cond, streaming=streaming)
         loss = F.mse_loss(pred * mask, u * mask, reduction="sum") / (torch.sum(mask) * u.shape[1])
-        if loss.isnan():
-            print(123)
-            pred_new = self.estimator(y, mask, mu, t.squeeze(), spks, cond)
         return loss, y
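
Note on the flow_matching.py change: compute_loss only forwards the new streaming flag to the estimator; the objective itself is unchanged, and a leftover NaN debug block is dropped. Below is a hedged sketch of the standard OT-CFM objective this loss corresponds to, written as a free function; the path definition, the sigma_min value, and the omission of CFG dropout are assumptions for illustration, not a copy of the repo's implementation.

# Hedged sketch of an OT conditional-flow-matching loss, with the `streaming`
# flag simply forwarded to the estimator as in the diff above.
import torch
import torch.nn.functional as F

def cfm_loss(estimator, x1, mask, mu, spks, cond, streaming=False, sigma_min=1e-4):
    b = x1.size(0)
    t = torch.rand(b, 1, 1, device=x1.device, dtype=x1.dtype)  # random flow time
    z = torch.randn_like(x1)                                    # noise endpoint x0
    y = (1 - (1 - sigma_min) * t) * z + t * x1                  # point on the linear OT path
    u = x1 - (1 - sigma_min) * z                                # target velocity field
    pred = estimator(y, mask, mu, t.squeeze(), spks, cond, streaming=streaming)
    return F.mse_loss(pred * mask, u * mask, reduction="sum") / (torch.sum(mask) * u.shape[1])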

cosyvoice/transformer/upsample_encoder.py

(+11 -10)

@@ -255,6 +255,7 @@ def forward(
         xs_lens: torch.Tensor,
         decoding_chunk_size: int = 0,
         num_decoding_left_chunks: int = -1,
+        streaming: bool = False,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """Embed positions in tensor.
 
@@ -286,11 +287,11 @@ def forward(
         xs, pos_emb, masks = self.embed(xs, masks)
         mask_pad = masks  # (B, 1, T/subsample_rate)
         chunk_masks = add_optional_chunk_mask(xs, masks,
-                                              self.use_dynamic_chunk,
-                                              self.use_dynamic_left_chunk,
-                                              decoding_chunk_size,
-                                              self.static_chunk_size,
-                                              num_decoding_left_chunks)
+                                              self.use_dynamic_chunk if streaming is True else False,
+                                              self.use_dynamic_left_chunk if streaming is True else False,
+                                              decoding_chunk_size if streaming is True else 0,
+                                              self.static_chunk_size if streaming is True else 0,
+                                              num_decoding_left_chunks if streaming is True else -1)
         # lookahead + conformer encoder
         xs, _ = self.pre_lookahead_layer(xs)
         xs = self.forward_layers(xs, chunk_masks, pos_emb, mask_pad)
@@ -304,11 +305,11 @@ def forward(
         xs, pos_emb, masks = self.up_embed(xs, masks)
         mask_pad = masks  # (B, 1, T/subsample_rate)
         chunk_masks = add_optional_chunk_mask(xs, masks,
-                                              self.use_dynamic_chunk,
-                                              self.use_dynamic_left_chunk,
-                                              decoding_chunk_size,
-                                              self.static_chunk_size * self.up_layer.stride,
-                                              num_decoding_left_chunks)
+                                              self.use_dynamic_chunk if streaming is True else False,
+                                              self.use_dynamic_left_chunk if streaming is True else False,
+                                              decoding_chunk_size if streaming is True else 0,
+                                              self.static_chunk_size * self.up_layer.stride if streaming is True else 0,
+                                              num_decoding_left_chunks if streaming is True else -1)
         xs = self.forward_up_layers(xs, chunk_masks, pos_emb, mask_pad)
 
         if self.normalize_before:
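
Note on the encoder change: the same gating is applied twice, once before the conformer layers and once after upsampling (where the static chunk size is scaled by the upsampling stride). A hedged sketch of that pattern as a standalone helper follows; chunk_mask_args and chunk_scale are illustrative names, not code from the repo.

# Hedged sketch of the gating pattern used in both add_optional_chunk_mask
# calls above: offline batches collapse every chunking knob to its
# full-context default.
def chunk_mask_args(encoder, decoding_chunk_size, num_decoding_left_chunks,
                    streaming, chunk_scale=1):
    if not streaming:
        # offline batch: full-context attention, no chunk boundaries
        return False, False, 0, 0, -1
    return (encoder.use_dynamic_chunk, encoder.use_dynamic_left_chunk,
            decoding_chunk_size, encoder.static_chunk_size * chunk_scale,
            num_decoding_left_chunks)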
