@@ -2,6 +2,7 @@
 # SPDX-License-Identifier: Apache-2.0
 
 import base64
+import json
 import os
 from io import BytesIO
 
@@ -16,7 +17,7 @@
 )
 from comps.cores.proto.docarray import LLMParams
 from fastapi import Request
-from fastapi.responses import StreamingResponse
+from fastapi.responses import JSONResponse, StreamingResponse
 from PIL import Image
 
 MEGA_SERVICE_PORT = int(os.getenv("MEGA_SERVICE_PORT", 8888))
@@ -29,6 +30,9 @@
 
 
 class MultimodalQnAService(Gateway):
+    asr_port = int(os.getenv("ASR_SERVICE_PORT", 3001))
+    asr_endpoint = os.getenv("ASR_SERVICE_ENDPOINT", "http://0.0.0.0:{}/v1/audio/transcriptions".format(asr_port))
+
    def __init__(self, host="0.0.0.0", port=8000):
        self.host = host
        self.port = port
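Note: with the defaults above, asr_endpoint resolves to http://0.0.0.0:3001/v1/audio/transcriptions, and ASR_SERVICE_ENDPOINT takes precedence over ASR_SERVICE_PORT when both are set. A minimal sketch of overriding either variable (these are class attributes, so the environment must be set before the module is imported; values here are illustrative):

    import os

    # Override just the port; used only when ASR_SERVICE_ENDPOINT is unset
    os.environ["ASR_SERVICE_PORT"] = "3101"
    # Or point at a remote ASR service directly; the full endpoint wins
    os.environ["ASR_SERVICE_ENDPOINT"] = "http://asr-service:3001/v1/audio/transcriptions"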
@@ -73,7 +77,10 @@ def add_remote_service(self):
     # this overrides _handle_message method of Gateway
     def _handle_message(self, messages):
         images = []
+        audios = []
+        b64_types = {}
         messages_dicts = []
+        decoded_audio_input = ""
         if isinstance(messages, str):
             prompt = messages
         else:
@@ -87,16 +94,26 @@ def _handle_message(self, messages):
87
94
system_prompt = message ["content" ]
88
95
elif msg_role == "user" :
89
96
if type (message ["content" ]) == list :
97
+ # separate each media type and store accordingly
90
98
text = ""
91
99
text_list = [item ["text" ] for item in message ["content" ] if item ["type" ] == "text" ]
92
100
text += "\n " .join (text_list )
93
101
image_list = [
94
102
item ["image_url" ]["url" ] for item in message ["content" ] if item ["type" ] == "image_url"
95
103
]
96
- if image_list :
97
- messages_dict [msg_role ] = (text , image_list )
98
- else :
104
+ audios = [item ["audio" ] for item in message ["content" ] if item ["type" ] == "audio" ]
105
+ if audios :
106
+ # translate audio to text. From this point forward, audio is treated like text
107
+ decoded_audio_input = self .convert_audio_to_text (audios )
108
+ b64_types ["audio" ] = decoded_audio_input
109
+
110
+ if text and not audios and not image_list :
99
111
messages_dict [msg_role ] = text
112
+ elif audios and not text and not image_list :
113
+ messages_dict [msg_role ] = decoded_audio_input
114
+ else :
115
+ messages_dict [msg_role ] = (text , decoded_audio_input , image_list )
116
+
100
117
else :
101
118
messages_dict [msg_role ] = message ["content" ]
102
119
messages_dicts .append (messages_dict )
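For reference, a sketch of the user-message shape this branch parses; the "type" keys and nesting mirror the list comprehensions above, while the values are made up:

    message = {
        "role": "user",
        "content": [
            {"type": "text", "text": "What is the speaker describing?"},
            {"type": "image_url", "image_url": {"url": "https://example.com/frame.png"}},
            {"type": "audio", "audio": "<base64-encoded audio>"},
        ],
    }

Text items are joined with newlines, image URLs are collected into image_list, and any audio items are transcribed via convert_audio_to_text before the role's entry is stored.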
@@ -108,55 +125,84 @@ def _handle_message(self, messages):
 
         if system_prompt:
             prompt = system_prompt + "\n"
-        for messages_dict in messages_dicts:
-            for i, (role, message) in enumerate(messages_dict.items()):
+        for i, messages_dict in enumerate(messages_dicts):
+            for role, message in messages_dict.items():
                 if isinstance(message, tuple):
-                    text, image_list = message
+                    text, decoded_audio_input, image_list = message
                     if i == 0:
                         # do not add role for the very first message.
                         # this will be added by llava_server
                         if text:
                             prompt += text + "\n"
+                        elif decoded_audio_input:
+                            prompt += decoded_audio_input + "\n"
                     else:
                         if text:
                             prompt += role.upper() + ": " + text + "\n"
+                        elif decoded_audio_input:
+                            prompt += role.upper() + ": " + decoded_audio_input + "\n"
                         else:
                             prompt += role.upper() + ":"
-                    for img in image_list:
-                        # URL
-                        if img.startswith("http://") or img.startswith("https://"):
-                            response = requests.get(img)
-                            image = Image.open(BytesIO(response.content)).convert("RGBA")
-                            image_bytes = BytesIO()
-                            image.save(image_bytes, format="PNG")
-                            img_b64_str = base64.b64encode(image_bytes.getvalue()).decode()
-                        # Local Path
-                        elif os.path.exists(img):
-                            image = Image.open(img).convert("RGBA")
-                            image_bytes = BytesIO()
-                            image.save(image_bytes, format="PNG")
-                            img_b64_str = base64.b64encode(image_bytes.getvalue()).decode()
-                        # Bytes
-                        else:
-                            img_b64_str = img
 
-                        images.append(img_b64_str)
-                else:
+                    if image_list:
+                        for img in image_list:
+                            # URL
+                            if img.startswith("http://") or img.startswith("https://"):
+                                response = requests.get(img)
+                                image = Image.open(BytesIO(response.content)).convert("RGBA")
+                                image_bytes = BytesIO()
+                                image.save(image_bytes, format="PNG")
+                                img_b64_str = base64.b64encode(image_bytes.getvalue()).decode()
+                            # Local Path
+                            elif os.path.exists(img):
+                                image = Image.open(img).convert("RGBA")
+                                image_bytes = BytesIO()
+                                image.save(image_bytes, format="PNG")
+                                img_b64_str = base64.b64encode(image_bytes.getvalue()).decode()
+                            # Bytes
+                            else:
+                                img_b64_str = img
+
+                            images.append(img_b64_str)
+
+                elif isinstance(message, str):
                     if i == 0:
                         # do not add role for the very first message.
                         # this will be added by llava_server
                         if message:
-                            prompt += role.upper() + ": " + message + "\n"
+                            prompt += message + "\n"
                     else:
                         if message:
                             prompt += role.upper() + ": " + message + "\n"
                         else:
                             prompt += role.upper() + ":"
+
         if images:
-            return prompt, images
+            b64_types["image"] = images
+
+        # If the query has multiple media types, return all types
+        if prompt and b64_types:
+            return prompt, b64_types
         else:
             return prompt
 
+    def convert_audio_to_text(self, audio):
+        # translate audio to text by passing in base64 encoded audio to ASR
+        if isinstance(audio, dict):
+            input_dict = {"byte_str": audio["audio"][0]}
+        else:
+            input_dict = {"byte_str": audio[0]}
+
+        response = requests.post(self.asr_endpoint, data=json.dumps(input_dict), proxies={"http": None})
+
+        if response.status_code != 200:
+            return JSONResponse(
+                status_code=503, content={"message": "Unable to convert audio to text. {}".format(response.text)}
+            )
+
+        response = response.json()
+        return response["query"]
+
     async def handle_request(self, request: Request):
         data = await request.json()
         stream_opt = bool(data.get("stream", False))
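The ASR exchange in convert_audio_to_text is a plain JSON POST. A minimal sketch of the same call, assuming the default local endpoint and a base64-encoded audio string audio_b64 (a name invented here for illustration):

    import json
    import requests

    audio_b64 = "<base64-encoded audio>"  # placeholder
    resp = requests.post(
        "http://0.0.0.0:3001/v1/audio/transcriptions",  # default asr_endpoint
        data=json.dumps({"byte_str": audio_b64}),
        proxies={"http": None},
    )
    resp.raise_for_status()
    print(resp.json()["query"])  # the transcribed text

Note that on a non-200 status the method returns a JSONResponse instead of raising, so callers receive the 503 payload in place of a transcript string.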
@@ -165,16 +211,35 @@ async def handle_request(self, request: Request):
             stream_opt = False
         chat_request = ChatCompletionRequest.model_validate(data)
         # Multimodal RAG QnA With Videos does not yet accept images as input during QnA.
-        prompt_and_image = self._handle_message(chat_request.messages)
-        if isinstance(prompt_and_image, tuple):
-            # print(f"This request include image, thus it is a follow-up query. Using lvm megaservice")
-            prompt, images = prompt_and_image
+        num_messages = len(data["messages"]) if isinstance(data["messages"], list) else 1
+        messages = self._handle_message(chat_request.messages)
+        decoded_audio_input = ""
+
+        if num_messages > 1:
+            # This is a follow-up query, go to LVM
             cur_megaservice = self.lvm_megaservice
-            initial_inputs = {"prompt": prompt, "image": images[0]}
+            if isinstance(messages, tuple):
+                prompt, b64_types = messages
+                if "audio" in b64_types:
+                    # for metadata storage purposes
+                    decoded_audio_input = b64_types["audio"]
+                if "image" in b64_types:
+                    initial_inputs = {"prompt": prompt, "image": b64_types["image"][0]}
+                else:
+                    initial_inputs = {"prompt": prompt, "image": ""}
+            else:
+                prompt = messages
+                initial_inputs = {"prompt": prompt, "image": ""}
         else:
-            # print(f"This is the first query, requiring multimodal retrieval. Using multimodal rag megaservice")
-            prompt = prompt_and_image
+            # This is the first query. Ignore image input
             cur_megaservice = self.megaservice
+            if isinstance(messages, tuple):
+                prompt, b64_types = messages
+                if "audio" in b64_types:
+                    # for metadata storage purposes
+                    decoded_audio_input = b64_types["audio"]
+            else:
+                prompt = messages
             initial_inputs = {"text": prompt}
 
         parameters = LLMParams(
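In short, the message count picks the pipeline: one message goes to the full multimodal RAG megaservice with {"text": prompt}, while a multi-turn history goes to the LVM megaservice with the prompt and the first base64 image, if any. A sketch of the two request shapes (field values illustrative only):

    # First turn: a single message, routed to the retrieval megaservice
    first_turn = {"messages": "What does the narrator say about the engine?"}

    # Follow-up: several messages, routed to the LVM megaservice
    follow_up = {
        "messages": [
            {"role": "user", "content": [{"type": "text", "text": "What does the narrator say?"}]},
            {"role": "assistant", "content": "The narrator explains the intake design."},
            {"role": "user", "content": "And what is shown on screen?"},
        ]
    }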
@@ -207,18 +272,24 @@ async def handle_request(self, request: Request):
         if "text" in result_dict[last_node].keys():
             response = result_dict[last_node]["text"]
         else:
-            # text in not response message
+            # text is not in response message
             # something wrong, for example due to empty retrieval results
             if "detail" in result_dict[last_node].keys():
                 response = result_dict[last_node]["detail"]
             else:
-                response = "The server fail to generate answer to your query!"
+                response = "The server failed to generate an answer to your query!"
         if "metadata" in result_dict[last_node].keys():
             # from retrieval results
             metadata = result_dict[last_node]["metadata"]
+            if decoded_audio_input:
+                metadata["audio"] = decoded_audio_input
         else:
             # follow-up question, no retrieval
-            metadata = None
+            if decoded_audio_input:
+                metadata = {"audio": decoded_audio_input}
+            else:
+                metadata = None
+
         choices = []
         usage = UsageInfo()
         choices.append(
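With audio input, the transcript is preserved in the response metadata in both branches; on a follow-up turn with no retrieval, the metadata reduces to just the transcript, e.g. (illustrative value):

    metadata = {"audio": "what the user's audio clip said"}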