Commit 0f19b97

Merge pull request #497 from FunAudioLLM/dev/lyuxiang.lx

2 parents (5157baf + a4db3db)

File tree: 3 files changed, +37 -16 lines

  cosyvoice/cli/model.py
  cosyvoice/flow/flow.py
  cosyvoice/flow/flow_matching.py

cosyvoice/cli/model.py (+18, -10)
```diff
@@ -53,6 +53,7 @@ def __init__(self,
         self.tts_speech_token_dict = {}
         self.llm_end_dict = {}
         self.mel_overlap_dict = {}
+        self.flow_cache_dict = {}
         self.hift_cache_dict = {}
 
     def load(self, llm_model, flow_model, hift_model):
@@ -100,15 +101,18 @@ def llm_job(self, text, prompt_text, llm_prompt_speech_token, llm_embedding, uui
             self.llm_end_dict[uuid] = True
 
     def token2wav(self, token, prompt_token, prompt_feat, embedding, uuid, finalize=False, speed=1.0):
-        tts_mel = self.flow.inference(token=token.to(self.device),
-                                      token_len=torch.tensor([token.shape[1]], dtype=torch.int32).to(self.device),
-                                      prompt_token=prompt_token.to(self.device),
-                                      prompt_token_len=torch.tensor([prompt_token.shape[1]], dtype=torch.int32).to(self.device),
-                                      prompt_feat=prompt_feat.to(self.device),
-                                      prompt_feat_len=torch.tensor([prompt_feat.shape[1]], dtype=torch.int32).to(self.device),
-                                      embedding=embedding.to(self.device))
+        tts_mel, flow_cache = self.flow.inference(token=token.to(self.device),
+                                                  token_len=torch.tensor([token.shape[1]], dtype=torch.int32).to(self.device),
+                                                  prompt_token=prompt_token.to(self.device),
+                                                  prompt_token_len=torch.tensor([prompt_token.shape[1]], dtype=torch.int32).to(self.device),
+                                                  prompt_feat=prompt_feat.to(self.device),
+                                                  prompt_feat_len=torch.tensor([prompt_feat.shape[1]], dtype=torch.int32).to(self.device),
+                                                  embedding=embedding.to(self.device),
+                                                  flow_cache=self.flow_cache_dict[uuid])
+        self.flow_cache_dict[uuid] = flow_cache
+
         # mel overlap fade in out
-        if self.mel_overlap_dict[uuid] is not None:
+        if self.mel_overlap_dict[uuid].shape[2] != 0:
             tts_mel = fade_in_out(tts_mel, self.mel_overlap_dict[uuid], self.mel_window)
         # append hift cache
         if self.hift_cache_dict[uuid] is not None:
@@ -145,7 +149,9 @@ def tts(self, text, flow_embedding, llm_embedding=torch.zeros(0, 192),
         this_uuid = str(uuid.uuid1())
         with self.lock:
             self.tts_speech_token_dict[this_uuid], self.llm_end_dict[this_uuid] = [], False
-            self.mel_overlap_dict[this_uuid], self.hift_cache_dict[this_uuid] = None, None
+            self.hift_cache_dict[this_uuid] = None
+            self.mel_overlap_dict[this_uuid] = torch.zeros(1, 80, 0)
+            self.flow_cache_dict[this_uuid] = torch.zeros(1, 80, 0, 2)
         p = threading.Thread(target=self.llm_job, args=(text, prompt_text, llm_prompt_speech_token, llm_embedding, this_uuid))
         p.start()
         if stream is True:
@@ -201,7 +207,9 @@ def vc(self, source_speech_token, flow_prompt_speech_token, prompt_speech_feat,
         this_uuid = str(uuid.uuid1())
         with self.lock:
             self.tts_speech_token_dict[this_uuid], self.llm_end_dict[this_uuid] = source_speech_token.flatten().tolist(), True
-            self.mel_overlap_dict[this_uuid], self.hift_cache_dict[this_uuid] = None, None
+            self.hift_cache_dict[this_uuid] = None
+            self.mel_overlap_dict[this_uuid] = torch.zeros(1, 80, 0)
+            self.flow_cache_dict[this_uuid] = torch.zeros(1, 80, 0, 2)
         if stream is True:
             token_hop_len = self.token_min_hop_len
             while True:
```
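The model.py change threads a per-session flow cache through `token2wav`: the cache is passed into `flow.inference`, the updated cache is stored back, and the per-uuid state is initialized as empty tensors rather than `None`, so "is this the first chunk?" becomes a shape check. A minimal runnable sketch of that lifecycle; `Session` and `token2wav_step` are hypothetical stand-ins, not names from the repo:

```python
import torch

class Session:
    """Hypothetical stand-in for the per-uuid dicts in CosyVoiceModel."""
    def __init__(self):
        # Empty tensors replace the old None sentinels, giving the first
        # chunk and every later chunk a single code path.
        self.mel_overlap = torch.zeros(1, 80, 0)    # (batch, n_mels, frames)
        self.flow_cache = torch.zeros(1, 80, 0, 2)  # (batch, n_mels, frames, 2)

def token2wav_step(session, tts_mel, new_flow_cache):
    # First-chunk detection is now a shape test, not an `is None` test.
    if session.mel_overlap.shape[2] != 0:
        pass  # fade_in_out(tts_mel, session.mel_overlap, window) runs here
    session.flow_cache = new_flow_cache  # carry flow state to the next chunk
    return tts_mel

session = Session()
mel = token2wav_step(session, torch.randn(1, 80, 120), torch.zeros(1, 80, 84, 2))
assert session.flow_cache.shape[2] == 84
```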

cosyvoice/flow/flow.py (+7, -4)
```diff
@@ -109,7 +109,8 @@ def inference(self,
                   prompt_token_len,
                   prompt_feat,
                   prompt_feat_len,
-                  embedding):
+                  embedding,
+                  flow_cache):
         assert token.shape[0] == 1
         # xvec projection
         embedding = F.normalize(embedding, dim=1)
@@ -133,13 +134,15 @@ def inference(self,
         conds = conds.transpose(1, 2)
 
         mask = (~make_pad_mask(torch.tensor([mel_len1 + mel_len2]))).to(h)
-        feat = self.decoder(
+        feat, flow_cache = self.decoder(
             mu=h.transpose(1, 2).contiguous(),
             mask=mask.unsqueeze(1),
             spks=embedding,
             cond=conds,
-            n_timesteps=10
+            n_timesteps=10,
+            prompt_len=mel_len1,
+            flow_cache=flow_cache
         )
         feat = feat[:, :, mel_len1:]
         assert feat.shape[2] == mel_len2
-        return feat
+        return feat, flow_cache
```
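With the new signature, `inference` both consumes and returns the cache, so every caller must thread it through chunk after chunk. A toy illustration of that contract, assuming nothing about the real module internals (`FakeFlow` is a stand-in, not `MaskedDiffWithXvec`):

```python
import torch

class FakeFlow:
    """Toy stand-in for the flow module's new (mel, cache) return contract."""
    def inference(self, token, flow_cache):
        mel = torch.randn(1, 80, token.shape[1] * 2)  # pretend 2 frames/token
        # A real cache holds fixed (z, mu) slices; any tensor of the right
        # shape is enough to demonstrate the threading pattern.
        new_cache = torch.zeros(1, 80, 34, 2)
        return mel, new_cache

flow = FakeFlow()
flow_cache = torch.zeros(1, 80, 0, 2)  # empty before the first chunk
mels = []
for chunk in [torch.zeros(1, 20, dtype=torch.long)] * 3:
    mel, flow_cache = flow.inference(chunk, flow_cache)  # cache in, cache out
    mels.append(mel)
speech_mel = torch.cat(mels, dim=2)
```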

cosyvoice/flow/flow_matching.py (+12, -2)
```diff
@@ -32,7 +32,7 @@ def __init__(self, in_channels, cfm_params, n_spks=1, spk_emb_dim=64, estimator:
         self.estimator = estimator
 
     @torch.inference_mode()
-    def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None):
+    def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None, prompt_len=0, flow_cache=torch.zeros(1, 80, 0, 2)):
         """Forward diffusion
 
         Args:
@@ -50,11 +50,21 @@ def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None):
             sample: generated mel-spectrogram
                 shape: (batch_size, n_feats, mel_timesteps)
         """
+
         z = torch.randn_like(mu) * temperature
+        cache_size = flow_cache.shape[2]
+        # fix prompt and overlap part mu and z
+        if cache_size != 0:
+            z[:, :, :cache_size] = flow_cache[:, :, :, 0]
+            mu[:, :, :cache_size] = flow_cache[:, :, :, 1]
+        z_cache = torch.concat([z[:, :, :prompt_len], z[:, :, -34:]], dim=2)
+        mu_cache = torch.concat([mu[:, :, :prompt_len], mu[:, :, -34:]], dim=2)
+        flow_cache = torch.stack([z_cache, mu_cache], dim=-1)
+
         t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device, dtype=mu.dtype)
         if self.t_scheduler == 'cosine':
             t_span = 1 - torch.cos(t_span * 0.5 * torch.pi)
-        return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond)
+        return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond), flow_cache
 
     def solve_euler(self, x, t_span, mu, mask, spks, cond):
         """
```
