
Commit f839c43

Authored by okhleif-IL <omar.khleif@intel.com>

Moved Audio Query Gateway changes to multimodalqna.py (#29)

* Moved gateway changes to multimodalqna.py
* reverted port changes
* addressed review comments
* reverted print statement

Signed-off-by: okhleif-IL <omar.khleif@intel.com>

1 parent: c421e68

File tree: 1 file changed, +109 -38 lines changed

MultimodalQnA/multimodalqna.py

Lines changed: 109 additions & 38 deletions
@@ -2,6 +2,7 @@
 # SPDX-License-Identifier: Apache-2.0
 
 import base64
+import json
 import os
 from io import BytesIO
 
@@ -16,7 +17,7 @@
 )
 from comps.cores.proto.docarray import LLMParams
 from fastapi import Request
-from fastapi.responses import StreamingResponse
+from fastapi.responses import JSONResponse, StreamingResponse
 from PIL import Image
 
 MEGA_SERVICE_PORT = int(os.getenv("MEGA_SERVICE_PORT", 8888))
@@ -29,6 +30,9 @@
 
 
 class MultimodalQnAService(Gateway):
+    asr_port = int(os.getenv("ASR_SERVICE_PORT", 3001))
+    asr_endpoint = os.getenv("ASR_SERVICE_ENDPOINT", "http://0.0.0.0:{}/v1/audio/transcriptions".format(asr_port))
+
     def __init__(self, host="0.0.0.0", port=8000):
         self.host = host
         self.port = port
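The ASR endpoint above is resolved from environment variables at class definition time. A minimal sketch of overriding the defaults before the gateway module is imported; the variable names come from the diff, while the host and port values here are placeholders:

import os

# Hypothetical deployment values; set these before importing multimodalqna
# so the class attributes pick them up.
os.environ["ASR_SERVICE_PORT"] = "3001"
os.environ["ASR_SERVICE_ENDPOINT"] = "http://asr-service:3001/v1/audio/transcriptions"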
@@ -73,7 +77,10 @@ def add_remote_service(self):
     # this overrides _handle_message method of Gateway
     def _handle_message(self, messages):
         images = []
+        audios = []
+        b64_types = {}
         messages_dicts = []
+        decoded_audio_input = ""
         if isinstance(messages, str):
             prompt = messages
         else:
@@ -87,16 +94,26 @@ def _handle_message(self, messages):
                     system_prompt = message["content"]
                 elif msg_role == "user":
                     if type(message["content"]) == list:
+                        # separate each media type and store accordingly
                         text = ""
                         text_list = [item["text"] for item in message["content"] if item["type"] == "text"]
                         text += "\n".join(text_list)
                         image_list = [
                             item["image_url"]["url"] for item in message["content"] if item["type"] == "image_url"
                         ]
-                        if image_list:
-                            messages_dict[msg_role] = (text, image_list)
-                        else:
+                        audios = [item["audio"] for item in message["content"] if item["type"] == "audio"]
+                        if audios:
+                            # translate audio to text. From this point forward, audio is treated like text
+                            decoded_audio_input = self.convert_audio_to_text(audios)
+                            b64_types["audio"] = decoded_audio_input
+
+                        if text and not audios and not image_list:
                             messages_dict[msg_role] = text
+                        elif audios and not text and not image_list:
+                            messages_dict[msg_role] = decoded_audio_input
+                        else:
+                            messages_dict[msg_role] = (text, decoded_audio_input, image_list)
+
                     else:
                         messages_dict[msg_role] = message["content"]
                     messages_dicts.append(messages_dict)
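For reference, a hypothetical request message of the shape this parser handles. The "type" keys and the "text", "audio", and "image_url" fields follow the comprehensions in the code above; the payload values are placeholders:

messages = [
    {
        "role": "user",
        "content": [
            {"type": "text", "text": "What is happening in this clip?"},
            {"type": "audio", "audio": "<base64-encoded audio bytes>"},
            {"type": "image_url", "image_url": {"url": "https://example.com/frame.png"}},
        ],
    }
]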
@@ -108,55 +125,84 @@ def _handle_message(self, messages):
 
         if system_prompt:
             prompt = system_prompt + "\n"
-        for messages_dict in messages_dicts:
-            for i, (role, message) in enumerate(messages_dict.items()):
+        for i, messages_dict in enumerate(messages_dicts):
+            for role, message in messages_dict.items():
                 if isinstance(message, tuple):
-                    text, image_list = message
+                    text, decoded_audio_input, image_list = message
                     if i == 0:
                         # do not add role for the very first message.
                         # this will be added by llava_server
                         if text:
                             prompt += text + "\n"
+                        elif decoded_audio_input:
+                            prompt += decoded_audio_input + "\n"
                     else:
                         if text:
                             prompt += role.upper() + ": " + text + "\n"
+                        elif decoded_audio_input:
+                            prompt += role.upper() + ": " + decoded_audio_input + "\n"
                         else:
                             prompt += role.upper() + ":"
-                    for img in image_list:
-                        # URL
-                        if img.startswith("http://") or img.startswith("https://"):
-                            response = requests.get(img)
-                            image = Image.open(BytesIO(response.content)).convert("RGBA")
-                            image_bytes = BytesIO()
-                            image.save(image_bytes, format="PNG")
-                            img_b64_str = base64.b64encode(image_bytes.getvalue()).decode()
-                        # Local Path
-                        elif os.path.exists(img):
-                            image = Image.open(img).convert("RGBA")
-                            image_bytes = BytesIO()
-                            image.save(image_bytes, format="PNG")
-                            img_b64_str = base64.b64encode(image_bytes.getvalue()).decode()
-                        # Bytes
-                        else:
-                            img_b64_str = img
 
-                        images.append(img_b64_str)
-                else:
+                    if image_list:
+                        for img in image_list:
+                            # URL
+                            if img.startswith("http://") or img.startswith("https://"):
+                                response = requests.get(img)
+                                image = Image.open(BytesIO(response.content)).convert("RGBA")
+                                image_bytes = BytesIO()
+                                image.save(image_bytes, format="PNG")
+                                img_b64_str = base64.b64encode(image_bytes.getvalue()).decode()
+                            # Local Path
+                            elif os.path.exists(img):
+                                image = Image.open(img).convert("RGBA")
+                                image_bytes = BytesIO()
+                                image.save(image_bytes, format="PNG")
+                                img_b64_str = base64.b64encode(image_bytes.getvalue()).decode()
+                            # Bytes
+                            else:
+                                img_b64_str = img
+
+                            images.append(img_b64_str)
+
+                elif isinstance(message, str):
                     if i == 0:
                         # do not add role for the very first message.
                         # this will be added by llava_server
                         if message:
-                            prompt += role.upper() + ": " + message + "\n"
+                            prompt += message + "\n"
                     else:
                         if message:
                             prompt += role.upper() + ": " + message + "\n"
                         else:
                             prompt += role.upper() + ":"
+
         if images:
-            return prompt, images
+            b64_types["image"] = images
+
+        # If the query has multiple media types, return all types
+        if prompt and b64_types:
+            return prompt, b64_types
         else:
             return prompt
 
+    def convert_audio_to_text(self, audio):
+        # translate audio to text by passing in base64 encoded audio to ASR
+        if isinstance(audio, dict):
+            input_dict = {"byte_str": audio["audio"][0]}
+        else:
+            input_dict = {"byte_str": audio[0]}
+
+        response = requests.post(self.asr_endpoint, data=json.dumps(input_dict), proxies={"http": None})
+
+        if response.status_code != 200:
+            return JSONResponse(
+                status_code=503, content={"message": "Unable to convert audio to text. {}".format(response.text)}
+            )
+
+        response = response.json()
+        return response["query"]
+
     async def handle_request(self, request: Request):
         data = await request.json()
         stream_opt = bool(data.get("stream", False))
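A minimal sketch of the ASR round trip that convert_audio_to_text performs, assuming an ASR service on the default endpoint; sample.wav is a placeholder path, and the {"byte_str": ...} request and {"query": ...} response shapes are taken from the code above:

import base64
import json

import requests

with open("sample.wav", "rb") as f:  # placeholder audio file
    b64_audio = base64.b64encode(f.read()).decode()

resp = requests.post(
    "http://0.0.0.0:3001/v1/audio/transcriptions",
    data=json.dumps({"byte_str": b64_audio}),
    proxies={"http": None},
)
print(resp.json()["query"])  # the transcribed text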
@@ -165,16 +211,35 @@ async def handle_request(self, request: Request):
             stream_opt = False
         chat_request = ChatCompletionRequest.model_validate(data)
         # Multimodal RAG QnA With Videos has not yet accepts image as input during QnA.
-        prompt_and_image = self._handle_message(chat_request.messages)
-        if isinstance(prompt_and_image, tuple):
-            # print(f"This request include image, thus it is a follow-up query. Using lvm megaservice")
-            prompt, images = prompt_and_image
+        num_messages = len(data["messages"]) if isinstance(data["messages"], list) else 1
+        messages = self._handle_message(chat_request.messages)
+        decoded_audio_input = ""
+
+        if num_messages > 1:
+            # This is a follow up query, go to LVM
             cur_megaservice = self.lvm_megaservice
-            initial_inputs = {"prompt": prompt, "image": images[0]}
+            if isinstance(messages, tuple):
+                prompt, b64_types = messages
+                if "audio" in b64_types:
+                    # for metadata storage purposes
+                    decoded_audio_input = b64_types["audio"]
+                if "image" in b64_types:
+                    initial_inputs = {"prompt": prompt, "image": b64_types["image"][0]}
+                else:
+                    initial_inputs = {"prompt": prompt, "image": ""}
+            else:
+                prompt = messages
+                initial_inputs = {"prompt": prompt, "image": ""}
         else:
-            # print(f"This is the first query, requiring multimodal retrieval. Using multimodal rag megaservice")
-            prompt = prompt_and_image
+            # This is the first query. Ignore image input
             cur_megaservice = self.megaservice
+            if isinstance(messages, tuple):
+                prompt, b64_types = messages
+                if "audio" in b64_types:
+                    # for metadata storage purposes
+                    decoded_audio_input = b64_types["audio"]
+            else:
+                prompt = messages
             initial_inputs = {"text": prompt}
 
         parameters = LLMParams(
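To illustrate the routing rule introduced here: a single-message request goes to the full retrieval megaservice, while anything with more than one entry in "messages" is treated as a follow-up and sent to the LVM megaservice. A hypothetical follow-up request body (contents are placeholders):

follow_up_request = {
    "messages": [
        {"role": "user", "content": "What does the video say about latency?"},
        {"role": "assistant", "content": "It reports the end-to-end latency figures."},
        {"role": "user", "content": "And about throughput?"},
    ],
    "stream": False,
}
# len(follow_up_request["messages"]) == 3 > 1, so cur_megaservice = self.lvm_megaservice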
@@ -207,18 +272,24 @@ async def handle_request(self, request: Request):
             if "text" in result_dict[last_node].keys():
                 response = result_dict[last_node]["text"]
             else:
-                # text in not response message
+                # text is not in response message
                 # something wrong, for example due to empty retrieval results
                 if "detail" in result_dict[last_node].keys():
                     response = result_dict[last_node]["detail"]
                 else:
-                    response = "The server fail to generate answer to your query!"
+                    response = "The server failed to generate an answer to your query!"
             if "metadata" in result_dict[last_node].keys():
                 # from retrieval results
                 metadata = result_dict[last_node]["metadata"]
+                if decoded_audio_input:
+                    metadata["audio"] = decoded_audio_input
             else:
                 # follow-up question, no retrieval
-                metadata = None
+                if decoded_audio_input:
+                    metadata = {"audio": decoded_audio_input}
+                else:
+                    metadata = None
+
             choices = []
             usage = UsageInfo()
             choices.append(
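The two metadata outcomes above, shown with placeholder values; the "video_id" field is an illustrative stand-in for whatever the retriever returns:

# retrieval path (first query): transcription merged into retrieval metadata
metadata = {"video_id": "clip_004", "audio": "what is the presenter saying"}

# follow-up path (no retrieval): transcription is the only metadata
metadata = {"audio": "what is the presenter saying"}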
