Skip to content

Commit 24f796a

Browse files
committed
Merge branch 'main' into dev/lyuxiang.lx
2 parents fd1a951 + 86e26f5 commit 24f796a

File tree

3 files changed

+9
-2
lines changed

3 files changed

+9
-2
lines changed

README.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -151,7 +151,7 @@ def text_generator():
151151
yield '那份意外的惊喜与深深的祝福'
152152
yield '让我心中充满了甜蜜的快乐,'
153153
yield '笑容如花儿般绽放。'
154-
for i, j in enumerate(cosyvoice.inference_zero_shot(text_generator, '希望你以后能够做的比我还好呦。', prompt_speech_16k, stream=False)):
154+
for i, j in enumerate(cosyvoice.inference_zero_shot(text_generator(), '希望你以后能够做的比我还好呦。', prompt_speech_16k, stream=False)):
155155
torchaudio.save('zero_shot_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)
156156
```
157157

cosyvoice/llm/llm.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -382,7 +382,10 @@ def inference_bistream(
382382
if text_cache.size(1) >= self.mix_ratio[0]:
383383
lm_input_text = text_cache[:, :self.mix_ratio[0]]
384384
logging.info('append {} text token'.format(lm_input_text.size(1)))
385-
lm_input = torch.concat([lm_input, lm_input_text], dim=1)
385+
if len(out_tokens) != 0 and out_tokens[-1] == self.speech_token_size + 2:
386+
lm_input = lm_input_text
387+
else:
388+
lm_input = torch.concat([lm_input, lm_input_text], dim=1)
386389
text_cache = text_cache[:, self.mix_ratio[0]:]
387390
else:
388391
logging.info('not enough text token to decode, wait for more')

runtime/python/fastapi/server.py

+4
Original file line numberDiff line numberDiff line change
@@ -44,26 +44,30 @@ def generate_data(model_output):
4444

4545

4646
@app.get("/inference_sft")
47+
@app.post("/inference_sft")
4748
async def inference_sft(tts_text: str = Form(), spk_id: str = Form()):
4849
model_output = cosyvoice.inference_sft(tts_text, spk_id)
4950
return StreamingResponse(generate_data(model_output))
5051

5152

5253
@app.get("/inference_zero_shot")
54+
@app.post("/inference_zero_shot")
5355
async def inference_zero_shot(tts_text: str = Form(), prompt_text: str = Form(), prompt_wav: UploadFile = File()):
5456
prompt_speech_16k = load_wav(prompt_wav.file, 16000)
5557
model_output = cosyvoice.inference_zero_shot(tts_text, prompt_text, prompt_speech_16k)
5658
return StreamingResponse(generate_data(model_output))
5759

5860

5961
@app.get("/inference_cross_lingual")
62+
@app.post("/inference_cross_lingual")
6063
async def inference_cross_lingual(tts_text: str = Form(), prompt_wav: UploadFile = File()):
6164
prompt_speech_16k = load_wav(prompt_wav.file, 16000)
6265
model_output = cosyvoice.inference_cross_lingual(tts_text, prompt_speech_16k)
6366
return StreamingResponse(generate_data(model_output))
6467

6568

6669
@app.get("/inference_instruct")
70+
@app.post("/inference_instruct")
6771
async def inference_instruct(tts_text: str = Form(), spk_id: str = Form(), instruct_text: str = Form()):
6872
model_output = cosyvoice.inference_instruct(tts_text, spk_id, instruct_text)
6973
return StreamingResponse(generate_data(model_output))

0 commit comments

Comments (0)