
Commit 363efe5

Merge pull request oobabooga#6199 from oobabooga/dev

Merge dev branch

2 parents: 2f71515 + 8b44d7b

28 files changed (+464, -230 lines)

css/html_instruct_style.css (+1 -1)

@@ -49,7 +49,7 @@
 
 .gradio-container .chat .assistant-message {
   padding: 20px;
-  background: var(--color-grey-200);
+  background: #f4f4f4;
   margin-top: 9px !important;
   margin-bottom: 12px !important;
   border-radius: 7px;

css/main.css (+8 -3)

@@ -95,7 +95,7 @@ gradio-app > :first-child {
 }
 
 .header_bar {
-  background-color: #f7f7f7;
+  background-color: #f4f4f4;
   box-shadow: 0 0 3px rgba(22 22 22 / 35%);
   margin-bottom: 0;
   overflow-x: scroll;
@@ -336,6 +336,11 @@ div.svelte-362y77>*, div.svelte-362y77>.form>* {
     padding-left: 0;
     padding-right: 0;
   }
+
+  .chat {
+    padding-left: 0;
+    padding-right: 0;
+  }
 }
 
 .chat {
@@ -391,7 +396,7 @@ div.svelte-362y77>*, div.svelte-362y77>.form>* {
 
 .chat .message:last-child {
   margin-bottom: 0 !important;
-  padding-bottom: 0 !important;
+  padding-bottom: 15px !important;
 }
 
 .message-body li {
@@ -510,7 +515,7 @@ div.svelte-362y77>*, div.svelte-362y77>.form>* {
 #show-controls {
   position: absolute;
   height: 100%;
-  background-color: var(--background-fill-primary);
+  background-color: transparent;
   border: 0 !important;
   border-radius: 0;
 }

extensions/sd_api_pictures/script.py (+1 -1)

@@ -33,7 +33,7 @@
     'hr_upscaler': 'ESRGAN_4x',
     'hr_scale': '1.0',
     'seed': -1,
-    'sampler_name': 'DPM++ 2M Karras',
+    'sampler_name': 'DPM++ 2M',
     'steps': 32,
     'cfg_scale': 7,
     'textgen_prefix': 'Please provide a detailed and vivid description of [subject]',
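The 'sampler_name' default above is the value the extension passes to the Stable Diffusion WebUI API, presumably updated because recent WebUI releases list "DPM++ 2M" and treat the Karras schedule separately. For illustration only, a rough standalone sketch of such a request; the endpoint URL and every payload field other than sampler_name/steps/cfg_scale/seed are assumptions, not taken from this commit.

# Hypothetical standalone txt2img request to a locally running SD WebUI;
# URL and prompt are illustrative assumptions.
import requests

payload = {
    "prompt": "a lighthouse at dusk, detailed, vivid",
    "sampler_name": "DPM++ 2M",  # new default; was "DPM++ 2M Karras"
    "steps": 32,
    "cfg_scale": 7,
    "seed": -1,
}

response = requests.post("http://127.0.0.1:7860/sdapi/v1/txt2img", json=payload, timeout=120)
response.raise_for_status()
images_base64 = response.json().get("images", [])  # base64-encoded PNG strings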

extensions/whisper_stt/script.js (+86, new file)

@@ -0,0 +1,86 @@
+console.log("Whisper STT script loaded");
+
+let mediaRecorder;
+let audioChunks = [];
+let isRecording = false;
+
+window.startStopRecording = function() {
+    if (!navigator.mediaDevices || !navigator.mediaDevices.getUserMedia) {
+        console.error("getUserMedia not supported on your browser!");
+        return;
+    }
+
+    if (isRecording == false) {
+        //console.log("Start recording function called");
+        navigator.mediaDevices.getUserMedia({ audio: true })
+            .then(stream => {
+                //console.log("Got audio stream");
+                mediaRecorder = new MediaRecorder(stream);
+                audioChunks = []; // Reset audio chunks
+                mediaRecorder.start();
+                //console.log("MediaRecorder started");
+                recButton.icon;
+                recordButton.innerHTML = recButton.innerHTML = "Stop";
+                isRecording = true;
+
+                mediaRecorder.addEventListener("dataavailable", event => {
+                    //console.log("Data available event, data size: ", event.data.size);
+                    audioChunks.push(event.data);
+                });
+
+                mediaRecorder.addEventListener("stop", () => {
+                    //console.log("MediaRecorder stopped");
+                    if (audioChunks.length > 0) {
+                        const audioBlob = new Blob(audioChunks, { type: "audio/webm" });
+                        //console.log("Audio blob created, size: ", audioBlob.size);
+                        const reader = new FileReader();
+                        reader.readAsDataURL(audioBlob);
+                        reader.onloadend = function() {
+                            const base64data = reader.result;
+                            //console.log("Audio converted to base64, length: ", base64data.length);
+
+                            const audioBase64Input = document.querySelector("#audio-base64 textarea");
+                            if (audioBase64Input) {
+                                audioBase64Input.value = base64data;
+                                audioBase64Input.dispatchEvent(new Event("input", { bubbles: true }));
+                                audioBase64Input.dispatchEvent(new Event("change", { bubbles: true }));
+                                //console.log("Updated textarea with base64 data");
+                            } else {
+                                console.error("Could not find audio-base64 textarea");
+                            }
+                        };
+                    } else {
+                        console.error("No audio data recorded for Whisper");
+                    }
+                });
+            });
+    } else {
+        //console.log("Stopping MediaRecorder");
+        recordButton.innerHTML = recButton.innerHTML = "Rec.";
+        isRecording = false;
+        mediaRecorder.stop();
+    }
+};
+
+const recordButton = gradioApp().querySelector("#record-button");
+recordButton.addEventListener("click", window.startStopRecording);
+
+
+function gradioApp() {
+    const elems = document.getElementsByTagName("gradio-app");
+    const gradioShadowRoot = elems.length == 0 ? null : elems[0].shadowRoot;
+    return gradioShadowRoot ? gradioShadowRoot : document;
+}
+
+
+// extra rec button next to generate button
+var recButton = recordButton.cloneNode(true);
+var generate_button = document.getElementById("Generate");
+generate_button.insertAdjacentElement("afterend", recButton);
+
+recButton.style.setProperty("margin-left", "-10px");
+recButton.innerHTML = "Rec.";
+
+recButton.addEventListener("click", function() {
+    recordButton.click();
+});

extensions/whisper_stt/script.py (+75 -27)

@@ -1,5 +1,13 @@
+import base64
+import gc
+import io
+from pathlib import Path
+
 import gradio as gr
-import speech_recognition as sr
+import numpy as np
+import torch
+import whisper
+from pydub import AudioSegment
 
 from modules import shared
 
@@ -8,13 +16,16 @@
     'value': ["", ""]
 }
 
-# parameters which can be customized in settings.json of webui
+# parameters which can be customized in settings.yaml of webui
 params = {
     'whipser_language': 'english',
     'whipser_model': 'small.en',
    'auto_submit': True
 }
 
+startup_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+WHISPERMODEL = whisper.load_model(params['whipser_model'], device=startup_device)
+
 
 def chat_input_modifier(text, visible_text, state):
     global input_hijack
@@ -25,47 +36,84 @@ def chat_input_modifier(text, visible_text, state):
     return text, visible_text
 
 
-def do_stt(audio, whipser_model, whipser_language):
-    transcription = ""
-    r = sr.Recognizer()
-
-    # Convert to AudioData
-    audio_data = sr.AudioData(sample_rate=audio[0], frame_data=audio[1], sample_width=4)
+def do_stt(audio, whipser_language):
+    # use pydub to convert sample_rate and sample_width for whisper input
+    dubaudio = AudioSegment.from_file(io.BytesIO(audio))
+    dubaudio = dubaudio.set_channels(1)
+    dubaudio = dubaudio.set_frame_rate(16000)
+    dubaudio = dubaudio.set_sample_width(2)
 
-    try:
-        transcription = r.recognize_whisper(audio_data, language=whipser_language, model=whipser_model)
-    except sr.UnknownValueError:
-        print("Whisper could not understand audio")
-    except sr.RequestError as e:
-        print("Could not request results from Whisper", e)
+    # same method to get the array as openai whisper repo used from wav file
+    audio_np = np.frombuffer(dubaudio.raw_data, np.int16).flatten().astype(np.float32) / 32768.0
 
-    return transcription
+    if len(whipser_language) == 0:
+        result = WHISPERMODEL.transcribe(audio=audio_np)
+    else:
+        result = WHISPERMODEL.transcribe(audio=audio_np, language=whipser_language)
+    return result["text"]
 
 
-def auto_transcribe(audio, auto_submit, whipser_model, whipser_language):
-    if audio is None:
+def auto_transcribe(audio, auto_submit, whipser_language):
+    if audio is None or audio == "":
+        print("Whisper received no audio data")
         return "", ""
-    transcription = do_stt(audio, whipser_model, whipser_language)
+    audio_bytes = base64.b64decode(audio.split(',')[1])
+
+    transcription = do_stt(audio_bytes, whipser_language)
     if auto_submit:
         input_hijack.update({"state": True, "value": [transcription, transcription]})
+    return transcription
+
+
+def reload_whispermodel(whisper_model_name: str, whisper_language: str, device: str):
+    if len(whisper_model_name) > 0:
+        global WHISPERMODEL
+        WHISPERMODEL = None
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
+        gc.collect()
 
-    return transcription, None
+        if device != "none":
+            if device == "cuda":
+                device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+            WHISPERMODEL = whisper.load_model(whisper_model_name, device=device)
+            params.update({"whipser_model": whisper_model_name})
+            if ".en" in whisper_model_name:
+                whisper_language = "english"
+            audio_update = gr.Audio.update(interactive=True)
+        else:
+            audio_update = gr.Audio.update(interactive=False)
+        return [whisper_model_name, whisper_language, str(device), audio_update]
 
 
 def ui():
     with gr.Accordion("Whisper STT", open=True):
         with gr.Row():
-            audio = gr.Audio(source="microphone")
+            audio = gr.Textbox(elem_id="audio-base64", visible=False)
+            record_button = gr.Button("Rec.", elem_id="record-button", elem_classes="custom-button")
         with gr.Row():
             with gr.Accordion("Settings", open=False):
                 auto_submit = gr.Checkbox(label='Submit the transcribed audio automatically', value=params['auto_submit'])
-                whipser_model = gr.Dropdown(label='Whisper Model', value=params['whipser_model'], choices=["tiny.en", "base.en", "small.en", "medium.en", "tiny", "base", "small", "medium", "large"])
-                whipser_language = gr.Dropdown(label='Whisper Language', value=params['whipser_language'], choices=["chinese", "german", "spanish", "russian", "korean", "french", "japanese", "portuguese", "turkish", "polish", "catalan", "dutch", "arabic", "swedish", "italian", "indonesian", "hindi", "finnish", "vietnamese", "hebrew", "ukrainian", "greek", "malay", "czech", "romanian", "danish", "hungarian", "tamil", "norwegian", "thai", "urdu", "croatian", "bulgarian", "lithuanian", "latin", "maori", "malayalam", "welsh", "slovak", "telugu", "persian", "latvian", "bengali", "serbian", "azerbaijani", "slovenian", "kannada", "estonian", "macedonian", "breton", "basque", "icelandic", "armenian", "nepali", "mongolian", "bosnian", "kazakh", "albanian", "swahili", "galician", "marathi", "punjabi", "sinhala", "khmer", "shona", "yoruba", "somali", "afrikaans", "occitan", "georgian", "belarusian", "tajik", "sindhi", "gujarati", "amharic", "yiddish", "lao", "uzbek", "faroese", "haitian creole", "pashto", "turkmen", "nynorsk", "maltese", "sanskrit", "luxembourgish", "myanmar", "tibetan", "tagalog", "malagasy", "assamese", "tatar", "hawaiian", "lingala", "hausa", "bashkir", "javanese", "sundanese"])
+                device_dropd = gr.Dropdown(label='Device', value=str(startup_device), choices=["cuda", "cpu", "none"])
+                whisper_model_dropd = gr.Dropdown(label='Whisper Model', value=params['whipser_model'], choices=["tiny.en", "base.en", "small.en", "medium.en", "tiny", "base", "small", "medium", "large"])
+                whisper_language = gr.Dropdown(label='Whisper Language', value=params['whipser_language'], choices=["english", "chinese", "german", "spanish", "russian", "korean", "french", "japanese", "portuguese", "turkish", "polish", "catalan", "dutch", "arabic", "swedish", "italian", "indonesian", "hindi", "finnish", "vietnamese", "hebrew", "ukrainian", "greek", "malay", "czech", "romanian", "danish", "hungarian", "tamil", "norwegian", "thai", "urdu", "croatian", "bulgarian", "lithuanian", "latin", "maori", "malayalam", "welsh", "slovak", "telugu", "persian", "latvian", "bengali", "serbian", "azerbaijani", "slovenian", "kannada", "estonian", "macedonian", "breton", "basque", "icelandic", "armenian", "nepali", "mongolian", "bosnian", "kazakh", "albanian", "swahili", "galician", "marathi", "punjabi", "sinhala", "khmer", "shona", "yoruba", "somali", "afrikaans", "occitan", "georgian", "belarusian", "tajik", "sindhi", "gujarati", "amharic", "yiddish", "lao", "uzbek", "faroese", "haitian creole", "pashto", "turkmen", "nynorsk", "maltese", "sanskrit", "luxembourgish", "myanmar", "tibetan", "tagalog", "malagasy", "assamese", "tatar", "hawaiian", "lingala", "hausa", "bashkir", "javanese", "sundanese"])
 
-        audio.stop_recording(
-            auto_transcribe, [audio, auto_submit, whipser_model, whipser_language], [shared.gradio['textbox'], audio]).then(
-            None, auto_submit, None, js="(check) => {if (check) { document.getElementById('Generate').click() }}")
+        audio.change(
+            auto_transcribe, [audio, auto_submit, whisper_language], [shared.gradio['textbox']]).then(
+            None, auto_submit, None, _js="(check) => {if (check) { document.getElementById('Generate').click() }}")
 
-        whipser_model.change(lambda x: params.update({"whipser_model": x}), whipser_model, None)
-        whipser_language.change(lambda x: params.update({"whipser_language": x}), whipser_language, None)
+        device_dropd.input(reload_whispermodel, [whisper_model_dropd, whisper_language, device_dropd], [whisper_model_dropd, whisper_language, device_dropd, audio])
+        whisper_model_dropd.change(reload_whispermodel, [whisper_model_dropd, whisper_language, device_dropd], [whisper_model_dropd, whisper_language, device_dropd, audio])
+        whisper_language.change(lambda x: params.update({"whipser_language": x}), whisper_language, None)
         auto_submit.change(lambda x: params.update({"auto_submit": x}), auto_submit, None)
+
+
+def custom_js():
+    """
+    Returns custom javascript as a string. It is applied whenever the web UI is
+    loaded.
+    :return:
+    """
+    with open(Path(__file__).parent.resolve() / "script.js", "r") as f:
+        return f.read()
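For readers following the new audio path end to end, here is a minimal standalone sketch (not part of the commit) of what do_stt() and auto_transcribe() do with the base64 data URL that script.js writes into the hidden #audio-base64 textbox. It assumes openai-whisper, pydub, numpy, and ffmpeg are available; the function name below is illustrative only.

# Minimal sketch of the extension's decode-and-transcribe path (assumptions noted above).
import base64
import io

import numpy as np
import whisper
from pydub import AudioSegment

model = whisper.load_model("small.en")  # same default as params['whipser_model']


def transcribe_data_url(data_url: str) -> str:
    # data_url is a "data:audio/webm;base64,..." string, like the one the
    # hidden textbox receives from the browser's MediaRecorder.
    audio_bytes = base64.b64decode(data_url.split(",")[1])

    # Normalize to 16 kHz mono 16-bit PCM, the format Whisper expects.
    seg = AudioSegment.from_file(io.BytesIO(audio_bytes))
    seg = seg.set_channels(1).set_frame_rate(16000).set_sample_width(2)

    # Scale int16 samples to float32 in [-1, 1], as the extension does.
    audio_np = np.frombuffer(seg.raw_data, np.int16).astype(np.float32) / 32768.0
    return model.transcribe(audio=audio_np)["text"]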
