Skip to content

Commit e06fd5b

Browse files
authored
feat(webui): support stream infer (#380)
1 parent d57dde1 commit e06fd5b

File tree

1 file changed

+40
-1
lines changed

1 file changed

+40
-1
lines changed

examples/web/webui.py

Lines changed: 40 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,40 @@ def generate_audio(text, temperature, top_P, top_K, audio_seed_input, text_seed_
7878

7979
return [(sample_rate, audio_data), text_data]
8080

81+
def generate_audio_stream(text, temperature, top_P, top_K, audio_seed_input, text_seed_input, refine_text_flag):
82+
83+
torch.manual_seed(audio_seed_input)
84+
rand_spk = chat.sample_random_speaker()
85+
params_infer_code = {
86+
'spk_emb': rand_spk,
87+
'temperature': temperature,
88+
'top_P': top_P,
89+
'top_K': top_K,
90+
}
91+
params_refine_text = {'prompt': '[oral_2][laugh_0][break_6]'}
92+
93+
torch.manual_seed(text_seed_input)
94+
95+
96+
wavs_gen = chat.infer(text,
97+
skip_refine_text=True,
98+
params_refine_text=params_refine_text,
99+
params_infer_code=params_infer_code,
100+
stream=True)
101+
102+
for gen in wavs_gen:
103+
wavs = [np.array([[]])]
104+
wavs[0] = np.hstack([wavs[0], np.array(gen[0])])
105+
audio = wavs[0][0]
106+
107+
max_audio = np.abs(audio).max() # 简单防止16bit爆音
108+
if max_audio > 1:
109+
audio /= max_audio
110+
111+
yield 24000,(audio * 32768).astype(np.int16)
112+
113+
114+
81115

82116
def main():
83117

@@ -103,9 +137,10 @@ def main():
103137
generate_text_seed = gr.Button("\U0001F3B2")
104138

105139
generate_button = gr.Button("Generate")
140+
stream_generate_button = gr.Button("Streaming Generate")
106141

107142
text_output = gr.Textbox(label="Output Text", interactive=False)
108-
audio_output = gr.Audio(label="Output Audio")
143+
audio_output = gr.Audio(label="Output Audio",value=None,streaming=True,autoplay=True,interactive=False,show_label=True)
109144

110145
# 使用Gradio的回调功能来更新数值输入框
111146
voice_selection.change(fn=on_voice_change, inputs=voice_selection, outputs=audio_seed_input)
@@ -121,6 +156,10 @@ def main():
121156
generate_button.click(generate_audio,
122157
inputs=[text_input, temperature_slider, top_p_slider, top_k_slider, audio_seed_input, text_seed_input, refine_text_checkbox],
123158
outputs=[audio_output, text_output])
159+
160+
stream_generate_button.click(generate_audio_stream,
161+
inputs=[text_input, temperature_slider, top_p_slider, top_k_slider, audio_seed_input, text_seed_input, refine_text_checkbox],
162+
outputs=[audio_output])
124163

125164
gr.Examples(
126165
examples=[

0 commit comments

Comments
 (0)