
Commit a69b7e2

fix vocoder train
1 parent fcc054f commit a69b7e2

12 files changed, +108 -17 lines

Diff for: cosyvoice/cli/model.py

+2-1
@@ -299,7 +299,8 @@ def __init__(self,
            self.flow.half()
        self.token_hop_len = self.flow.encoder.static_chunk_size
        # flow decoder required_cache_size
-        self.flow_decoder_required_cache_size = self.flow.decoder.estimator.num_decoding_left_chunks * self.flow.decoder.estimator.static_chunk_size
+        # TODO the base model was trained without num_decoding_left_chunks set; it has to be retrained before flow_decoder_required_cache_size can be derived from it
+        self.flow_decoder_required_cache_size = 999
        # hift cache
        self.mel_cache_len = 8
        self.source_cache_len = int(self.mel_cache_len * 480)
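Note on this change: the removed line derived the flow decoder's streaming cache from the estimator's chunk settings (left chunks times static chunk size); the hard-coded 999 effectively keeps the full history until a base model trained with num_decoding_left_chunks is available. A minimal sketch of that relationship, with illustrative values that are not from the repo:

```python
# Illustrative sketch of how a left-chunk budget becomes a cache size for
# streaming flow decoding. The helper name and the numbers are assumptions.

def flow_decoder_cache_size(num_decoding_left_chunks: int, static_chunk_size: int) -> int:
    """Frames of past context the decoder must keep between streaming calls."""
    return num_decoding_left_chunks * static_chunk_size

# e.g. 2 left chunks of 50 frames each -> keep 100 frames of cache
print(flow_decoder_cache_size(2, 50))  # 100

# The commit instead pins a very large value, so nothing is trimmed for now.
flow_decoder_required_cache_size = 999
```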

Diff for: cosyvoice/flow/flow.py

+2-2
@@ -91,7 +91,7 @@ def forward(
        conds = conds.transpose(1, 2)

        mask = (~make_pad_mask(feat_len)).to(h)
-        # NOTE this line should be unnecessary; h has already gone through the length_regulator and has the same shape as feat
+        # NOTE this is unnecessary, feat/h already same shape
        feat = F.interpolate(feat.unsqueeze(dim=1), size=h.shape[1:], mode="nearest").squeeze(dim=1)
        loss, _ = self.decoder.compute_loss(
            feat.transpose(1, 2).contiguous(),
@@ -117,7 +117,7 @@ def inference(self,
        embedding = F.normalize(embedding, dim=1)
        embedding = self.spk_embed_affine_layer(embedding)

-        # concat text and prompt_text
+        # concat speech token and prompt speech token
        token_len1, token_len2 = prompt_token.shape[1], token.shape[1]
        token, token_len = torch.concat([prompt_token, token], dim=1), prompt_token_len + token_len
        mask = (~make_pad_mask(token_len)).unsqueeze(-1).to(embedding)
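The NOTE in the first hunk can be checked directly: when feat and the length-regulated h already share a shape, the nearest-neighbour interpolation is an identity operation. A small self-contained check, with shapes assumed purely for illustration:

```python
# Sanity check for the NOTE above: nearest interpolation to the same size is a no-op.
# Shapes (B=2, T=100, D=80) are assumptions for illustration only.
import torch
import torch.nn.functional as F

feat = torch.randn(2, 100, 80)
h = torch.randn(2, 100, 80)  # length-regulator output with the same length as feat

resized = F.interpolate(feat.unsqueeze(dim=1), size=h.shape[1:], mode="nearest").squeeze(dim=1)
assert torch.equal(resized, feat)  # identical values when target size equals input size
```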

Diff for: cosyvoice/flow/length_regulator.py

+1
@@ -51,6 +51,7 @@ def forward(self, x, ylens=None):

    def inference(self, x1, x2, mel_len1, mel_len2, input_frame_rate=50):
        # in inference mode, interpolate prompt token and token (head/mid/tail) separately, so we can get a clear separation point of mel
+        # NOTE 20 corresponds to token_overlap_len in cosyvoice/cli/model.py
        # x in (B, T, D)
        if x2.shape[1] > 40:
            x2_head = F.interpolate(x2[:, :20].transpose(1, 2).contiguous(), size=int(20 / input_frame_rate * 22050 / 256), mode='linear')
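The added NOTE ties the hard-coded 20 to token_overlap_len in cosyvoice/cli/model.py. The size= expression is plain arithmetic: 20 tokens at 50 tokens/s span 0.4 s, which at a 22050 Hz sample rate with a 256-sample mel hop is int(0.4 * 22050 / 256) = 34 mel frames. A quick worked check using the defaults visible in the diff:

```python
# Worked check of the size= expression above (values from the diff's defaults).
input_frame_rate = 50   # speech tokens per second (function default)
sample_rate = 22050     # Hz
hop_length = 256        # samples per mel frame
overlap_tokens = 20     # the token_overlap_len referenced in the NOTE

seconds = overlap_tokens / input_frame_rate            # 0.4 s
mel_frames = int(seconds * sample_rate / hop_length)   # int(34.45...) -> 34
print(mel_frames)  # 34
```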

Diff for: cosyvoice/hifigan/discriminator.py

+89-2
@@ -1,13 +1,16 @@
 import torch
 import torch.nn as nn
+import torch.nn.functional as F
 try:
-    from torch.nn.utils.parametrizations import weight_norm
+    from torch.nn.utils.parametrizations import weight_norm, spectral_norm
 except ImportError:
-    from torch.nn.utils import weight_norm
+    from torch.nn.utils import weight_norm, spectral_norm
 from typing import List, Optional, Tuple
 from einops import rearrange
 from torchaudio.transforms import Spectrogram

+LRELU_SLOPE = 0.1
+

 class MultipleDiscriminator(nn.Module):
     def __init__(
@@ -141,3 +144,87 @@ def forward(self, x: torch.Tensor, cond_embedding_id: torch.Tensor = None):
         x += h

         return x, fmap
+
+
+class MultiResSpecDiscriminator(torch.nn.Module):
+
+    def __init__(self,
+                 fft_sizes=[1024, 2048, 512],
+                 hop_sizes=[120, 240, 50],
+                 win_lengths=[600, 1200, 240],
+                 window="hann_window"):
+
+        super(MultiResSpecDiscriminator, self).__init__()
+        self.discriminators = nn.ModuleList([
+            SpecDiscriminator(fft_sizes[0], hop_sizes[0], win_lengths[0], window),
+            SpecDiscriminator(fft_sizes[1], hop_sizes[1], win_lengths[1], window),
+            SpecDiscriminator(fft_sizes[2], hop_sizes[2], win_lengths[2], window)])
+
+    def forward(self, y, y_hat):
+        y_d_rs = []
+        y_d_gs = []
+        fmap_rs = []
+        fmap_gs = []
+        for i, d in enumerate(self.discriminators):
+            y_d_r, fmap_r = d(y)
+            y_d_g, fmap_g = d(y_hat)
+            y_d_rs.append(y_d_r)
+            fmap_rs.append(fmap_r)
+            y_d_gs.append(y_d_g)
+            fmap_gs.append(fmap_g)
+
+        return y_d_rs, y_d_gs, fmap_rs, fmap_gs
+
+
+def stft(x, fft_size, hop_size, win_length, window):
+    """Perform STFT and convert to magnitude spectrogram.
+    Args:
+        x (Tensor): Input signal tensor (B, T).
+        fft_size (int): FFT size.
+        hop_size (int): Hop size.
+        win_length (int): Window length.
+        window (str): Window function type.
+    Returns:
+        Tensor: Magnitude spectrogram (B, #frames, fft_size // 2 + 1).
+    """
+    x_stft = torch.stft(x, fft_size, hop_size, win_length, window, return_complex=True)
+
+    # NOTE(kan-bayashi): clamp is needed to avoid nan or inf
+    return torch.abs(x_stft).transpose(2, 1)
+
+
+class SpecDiscriminator(nn.Module):
+    """docstring for Discriminator."""
+
+    def __init__(self, fft_size=1024, shift_size=120, win_length=600, window="hann_window", use_spectral_norm=False):
+        super(SpecDiscriminator, self).__init__()
+        norm_f = weight_norm if use_spectral_norm is False else spectral_norm
+        self.fft_size = fft_size
+        self.shift_size = shift_size
+        self.win_length = win_length
+        self.window = getattr(torch, window)(win_length)
+        self.discriminators = nn.ModuleList([
+            norm_f(nn.Conv2d(1, 32, kernel_size=(3, 9), padding=(1, 4))),
+            norm_f(nn.Conv2d(32, 32, kernel_size=(3, 9), stride=(1, 2), padding=(1, 4))),
+            norm_f(nn.Conv2d(32, 32, kernel_size=(3, 9), stride=(1, 2), padding=(1, 4))),
+            norm_f(nn.Conv2d(32, 32, kernel_size=(3, 9), stride=(1, 2), padding=(1, 4))),
+            norm_f(nn.Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))),
+        ])
+
+        self.out = norm_f(nn.Conv2d(32, 1, 3, 1, 1))
+
+    def forward(self, y):
+
+        fmap = []
+        y = y.squeeze(1)
+        y = stft(y, self.fft_size, self.shift_size, self.win_length, self.window.to(y.device))
+        y = y.unsqueeze(1)
+        for i, d in enumerate(self.discriminators):
+            y = d(y)
+            y = F.leaky_relu(y, LRELU_SLOPE)
+            fmap.append(y)
+
+        y = self.out(y)
+        fmap.append(y)
+
+        return torch.flatten(y, 1, -1), fmap
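A minimal usage sketch of the new multi-resolution spectrogram discriminator; batch size and waveform length are arbitrary values for illustration:

```python
# Toy usage of the class added above. The import path matches the yaml reference
# cosyvoice.hifigan.discriminator.MultiResSpecDiscriminator; shapes are illustrative.
import torch
from cosyvoice.hifigan.discriminator import MultiResSpecDiscriminator

mrd = MultiResSpecDiscriminator()
real = torch.randn(4, 1, 24000)  # (B, 1, T) real speech
fake = torch.randn(4, 1, 24000)  # (B, 1, T) generator output

y_d_rs, y_d_gs, fmap_rs, fmap_gs = mrd(real, fake)
print(len(y_d_rs))      # 3: one score tensor per STFT resolution
print(y_d_rs[0].shape)  # (B, flattened logits from SpecDiscriminator)
```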

Diff for: cosyvoice/hifigan/hifigan.py

+1-1
@@ -56,7 +56,7 @@ def forward_discriminator(self, batch, device):
        with torch.no_grad():
            generated_speech, generated_f0 = self.generator(batch, device)
        # 2. calculate discriminator outputs
-        y_d_rs, y_d_gs, fmap_rs, fmap_gs = self.discriminator(real_speech, generated_speech)
+        y_d_rs, y_d_gs, fmap_rs, fmap_gs = self.discriminator(real_speech, generated_speech.detach())
        # 3. calculate discriminator losses, tpr losses [Optional]
        loss_disc, _, _ = discriminator_loss(y_d_rs, y_d_gs)
        if self.tpr_loss_weight != 0:
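Here the generator forward already runs under torch.no_grad(), so the added .detach() is a defensive guarantee that the discriminator loss can never backpropagate into the generator, even if the no_grad context were removed later. A self-contained toy sketch of the pattern (one-layer models, not the repo's HiFiGan classes):

```python
# Minimal sketch of why the discriminator step detaches the generator output.
import torch
import torch.nn as nn

generator = nn.Linear(8, 8)
discriminator = nn.Linear(8, 1)

real = torch.randn(4, 8)
fake = generator(torch.randn(4, 8))

# Discriminator step: detach() cuts the graph so loss_d cannot reach the generator.
loss_d = discriminator(fake.detach()).mean() - discriminator(real).mean()
loss_d.backward()
assert generator.weight.grad is None  # generator untouched by the discriminator step

# The generator step would instead use discriminator(fake) *without* detach.
```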

Diff for: cosyvoice/llm/llm.py

+2-1
@@ -326,7 +326,8 @@ def prepare_lm_input_target(self, text_token, text_token_emb, text_token_len, sp
            # unistream sequence
            else:
                this_lm_target = torch.tensor([IGNORE_ID] * (1 + text_token_len[i]) + speech_token[i].tolist() + [self.speech_token_size])
-                this_lm_input = torch.concat([self.llm_embedding.weight[self.sos_eos].reshape(1, -1), text_token_emb[i], self.llm_embedding.weight[self.task_id].reshape(1, -1), speech_token_emb[i]], dim=0)
+                this_lm_input = torch.concat([self.llm_embedding.weight[self.sos_eos].reshape(1, -1), text_token_emb[i],
+                                              self.llm_embedding.weight[self.task_id].reshape(1, -1), speech_token_emb[i]], dim=0)
            lm_target.append(this_lm_target)
            lm_input.append(this_lm_input)
        lm_input_len = torch.tensor([i.size(0) for i in lm_input], dtype=torch.int32)
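This hunk only rewraps the line, but the layout it builds is worth spelling out: the unistream input is [sos/eos embedding] + text embeddings + [task_id embedding] + speech-token embeddings, aligned against a target that ignores the 1 + text_len prefix and ends with speech_token_size as the end-of-speech id. A toy illustration of that alignment; IGNORE_ID, the token values, and the vocabulary size are placeholders:

```python
# Toy illustration of the unistream input/target alignment built above.
# Only the lengths and layout mirror the diff; all values are placeholders.
IGNORE_ID = -100
speech_token_size = 6561   # assumed speech vocabulary size, reused as the EOS id

text_len = 3
speech = [17, 42, 99]      # placeholder speech-token ids

# Target: ignore the <sos/eos> slot and every text position, then predict the
# speech tokens and finally the end-of-speech id.
lm_target = [IGNORE_ID] * (1 + text_len) + speech + [speech_token_size]

# Input layout (embeddings in the real code): <sos/eos>, text..., <task_id>, speech...
lm_input = ["<sos/eos>"] + [f"text_{i}" for i in range(text_len)] + ["<task_id>"] + [f"speech_{t}" for t in speech]

assert len(lm_input) == len(lm_target)  # positions line up one-to-one
```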

Diff for: cosyvoice/utils/train_utils.py

+1-1
@@ -340,7 +340,7 @@ def log_per_save(writer, info_dict):
    rank = int(os.environ.get('RANK', 0))
    logging.info(
        'Epoch {} Step {} CV info lr {} {} rank {}'.format(
-            epoch, step + 1, lr, rank, ' '.join(['{}_{}'.format(k, v) for k, v in loss_dict.items()])))
+            epoch, step + 1, lr, rank, ' '.join(['{} {}'.format(k, v) for k, v in loss_dict.items()])))

    if writer is not None:
        for k in ['epoch', 'lr']:

Diff for: examples/libritts/cosyvoice/conf/cosyvoice.fromscratch.yaml

+1-1
@@ -147,7 +147,7 @@ hifigan: !new:cosyvoice.hifigan.hifigan.HiFiGan
    generator: !ref <hift>
    discriminator: !new:cosyvoice.hifigan.discriminator.MultipleDiscriminator
        mpd: !new:matcha.hifigan.models.MultiPeriodDiscriminator
-        mrd: !new:cosyvoice.hifigan.discriminator.MultiResolutionDiscriminator
+        mrd: !new:cosyvoice.hifigan.discriminator.MultiResSpecDiscriminator
    mel_spec_transform: [
        !ref <mel_spec_transform1>
    ]

Diff for: examples/libritts/cosyvoice/conf/cosyvoice.yaml

+1-1
@@ -147,7 +147,7 @@ hifigan: !new:cosyvoice.hifigan.hifigan.HiFiGan
    generator: !ref <hift>
    discriminator: !new:cosyvoice.hifigan.discriminator.MultipleDiscriminator
        mpd: !new:matcha.hifigan.models.MultiPeriodDiscriminator
-        mrd: !new:cosyvoice.hifigan.discriminator.MultiResolutionDiscriminator
+        mrd: !new:cosyvoice.hifigan.discriminator.MultiResSpecDiscriminator
    mel_spec_transform: [
        !ref <mel_spec_transform1>
    ]

Diff for: examples/libritts/cosyvoice2/conf/cosyvoice2.yaml

+6-6
@@ -14,8 +14,8 @@ token_frame_rate: 25
token_mel_ratio: 2

# stream related params
-chunk_size: 1 # streaming inference chunk size, in second
-num_decoding_left_chunks: 2 # streaming inference flow decoder left chunk size, in second
+chunk_size: 2 # streaming inference chunk size, in second
+num_decoding_left_chunks: 1 # streaming inference flow decoder left chunk size

# model params
# for all class/function included in this repo, we use !<name> or !<new> for initialization, so that user may find all corresponding class/function according to one single yaml.
@@ -112,19 +112,19 @@ hift: !new:cosyvoice.hifigan.generator.HiFTGenerator

# gan related module
mel_spec_transform1: !name:matcha.utils.audio.mel_spectrogram
-    n_fft: 1024
+    n_fft: 1920
    num_mels: 80
    sampling_rate: !ref <sample_rate>
-    hop_size: 256
-    win_size: 1024
+    hop_size: 480
+    win_size: 1920
    fmin: 0
    fmax: null
    center: False
hifigan: !new:cosyvoice.hifigan.hifigan.HiFiGan
    generator: !ref <hift>
    discriminator: !new:cosyvoice.hifigan.discriminator.MultipleDiscriminator
        mpd: !new:matcha.hifigan.models.MultiPeriodDiscriminator
-        mrd: !new:cosyvoice.hifigan.discriminator.MultiResolutionDiscriminator
+        mrd: !new:cosyvoice.hifigan.discriminator.MultiResSpecDiscriminator
    mel_spec_transform: [
        !ref <mel_spec_transform1>
    ]
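The new STFT settings line the GAN mel features up with the rest of this config: token_frame_rate 25 with token_mel_ratio 2 gives a 50 Hz mel frame rate, so, assuming the referenced <sample_rate> resolves to 24000 (the yaml only shows the !ref), the hop must be 24000 / 50 = 480 samples, with a 4x window of 1920 samples and a matching FFT size. A quick arithmetic check:

```python
# Consistency check for the new mel_spectrogram settings.
# sample_rate = 24000 is an assumption; the diff only shows !ref <sample_rate>.
sample_rate = 24000
token_frame_rate = 25   # from the top of cosyvoice2.yaml
token_mel_ratio = 2

mel_frame_rate = token_frame_rate * token_mel_ratio  # 50 mel frames per second
hop_size = sample_rate // mel_frame_rate              # 24000 / 50 = 480
win_size = 4 * hop_size                               # 1920
n_fft = win_size                                      # 1920

print(hop_size, win_size, n_fft)  # 480 1920 1920
```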

Diff for: examples/libritts/cosyvoice2/run.sh

+1-1
@@ -71,7 +71,7 @@ if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
 fi

 # train llm
-export CUDA_VISIBLE_DEVICES="2,3,4,5,6,7"
+export CUDA_VISIBLE_DEVICES="0,1,2,3"
 num_gpus=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
 job_id=1986
 dist_backend="nccl"

Diff for: requirements.txt

+1
@@ -21,6 +21,7 @@ onnxruntime-gpu==1.18.0; sys_platform == 'linux'
 onnxruntime==1.18.0; sys_platform == 'darwin' or sys_platform == 'windows'
 openai-whisper==20231117
 protobuf==4.25
+pyarrow==18.1.0
 pydantic==2.7.0
 pyworld==0.3.4
 rich==13.7.1
