|
| 1 | +# Copyright (C) 2024 Intel Corporation |
| 2 | +# SPDX-License-Identifier: Apache-2.0 |
| 3 | + |
| 4 | +import asyncio |
| 5 | +import base64 |
| 6 | +import os |
| 7 | + |
| 8 | +from comps import AudioQnAGateway, MicroService, ServiceOrchestrator, ServiceType |
| 9 | + |
# Host/port the AudioQnA mega-service itself binds to (env-overridable).
MEGA_SERVICE_HOST_IP = os.getenv("MEGA_SERVICE_HOST_IP", "0.0.0.0")
MEGA_SERVICE_PORT = int(os.getenv("MEGA_SERVICE_PORT", 8888))

# Endpoints of the three backing microservices: Whisper (ASR),
# GPT-SoVITS (TTS), and the LLM serving endpoint. All env-overridable.
WHISPER_SERVER_HOST_IP = os.getenv("WHISPER_SERVER_HOST_IP", "0.0.0.0")
WHISPER_SERVER_PORT = int(os.getenv("WHISPER_SERVER_PORT", 7066))
GPT_SOVITS_SERVER_HOST_IP = os.getenv("GPT_SOVITS_SERVER_HOST_IP", "0.0.0.0")
GPT_SOVITS_SERVER_PORT = int(os.getenv("GPT_SOVITS_SERVER_PORT", 9088))
LLM_SERVER_HOST_IP = os.getenv("LLM_SERVER_HOST_IP", "0.0.0.0")
# NOTE(review): defaults to 8888, the same as MEGA_SERVICE_PORT — fine when the
# LLM runs on a different host, but a port clash if both run locally; confirm.
LLM_SERVER_PORT = int(os.getenv("LLM_SERVER_PORT", 8888))
| 19 | + |
| 20 | + |
def align_inputs(self, inputs, cur_node, runtime_graph, llm_parameters_dict, **kwargs):
    """Adapt the payload flowing into each microservice node.

    Translates the generic pipeline payload into the request schema each
    downstream service expects:

    * ASR  -- rename the base64 audio field ``byte_str`` -> ``audio``
      (e.g. ``{'byte_str': 'UklGRigAAABXQVZF...'}``).
    * LLM  -- build a unified OpenAI ``/v1/chat/completions`` request from
      the ASR transcript plus the generation parameters.
    * TTS  -- extract the assistant message text and attach the target
      language from the ``tts_text_language`` kwarg (default ``"zh"``).

    Any other service type receives the payload unchanged.
    """
    # Fix: dropped the debug `print(inputs)` — it leaked every request payload
    # (including base64 audio) to stdout on each node transition.
    service_type = self.services[cur_node].service_type
    if service_type == ServiceType.ASR:
        # Whisper expects the base64 WAV under "audio", not "byte_str".
        inputs["audio"] = inputs.pop("byte_str")
    elif service_type == ServiceType.LLM:
        # Convert TGI/vLLM to the unified OpenAI /v1/chat/completions format.
        inputs = {
            "model": "tgi",  # fake model name, kept only to unify the format
            "messages": [{"role": "user", "content": inputs["asr_result"]}],
            "max_tokens": llm_parameters_dict["max_tokens"],
            "top_p": llm_parameters_dict["top_p"],
            "stream": inputs["streaming"],  # False by default
            "frequency_penalty": inputs["frequency_penalty"],
            # "presence_penalty": inputs["presence_penalty"],
            # "repetition_penalty": inputs["repetition_penalty"],
            "temperature": inputs["temperature"],
        }
    elif service_type == ServiceType.TTS:
        inputs = {
            "text": inputs["choices"][0]["message"]["content"],
            "text_language": kwargs.get("tts_text_language", "zh"),
        }
    return inputs
| 46 | + |
| 47 | + |
def align_outputs(self, data, cur_node, inputs, runtime_graph, llm_parameters_dict, **kwargs):
    """Adapt each node's raw response before it flows to the next node.

    TTS replies arrive as raw audio bytes; wrap them as a base64 string
    under the ``byte_str`` key.  Every other service's response passes
    through untouched.
    """
    if self.services[cur_node].service_type != ServiceType.TTS:
        return data
    encoded = base64.b64encode(data).decode("utf-8")
    return {"byte_str": encoded}
| 53 | + |
| 54 | + |
class AudioQnAService:
    """Mega-service that chains ASR -> LLM -> TTS into one AudioQnA pipeline."""

    def __init__(self, host="0.0.0.0", port=8000):
        self.host = host
        self.port = port
        # Install the AudioQnA-specific payload adapters on the orchestrator
        # class so every node transition uses our request/response shapes.
        ServiceOrchestrator.align_inputs = align_inputs
        ServiceOrchestrator.align_outputs = align_outputs
        self.megaservice = ServiceOrchestrator()

    def add_remote_service(self):
        """Register the three remote microservices, chain them, and open the gateway.

        Builds the DAG asr -> llm -> tts and exposes it through an
        AudioQnAGateway listening on ``self.host:self.port``.
        """
        asr = MicroService(
            name="asr",
            host=WHISPER_SERVER_HOST_IP,
            port=WHISPER_SERVER_PORT,
            # Whisper server route here is /v1/asr, not /v1/audio/transcriptions.
            endpoint="/v1/asr",
            use_remote_service=True,
            service_type=ServiceType.ASR,
        )
        llm = MicroService(
            name="llm",
            host=LLM_SERVER_HOST_IP,
            port=LLM_SERVER_PORT,
            # OpenAI-compatible chat endpoint (see align_inputs LLM branch).
            endpoint="/v1/chat/completions",
            use_remote_service=True,
            service_type=ServiceType.LLM,
        )
        tts = MicroService(
            name="tts",
            host=GPT_SOVITS_SERVER_HOST_IP,
            port=GPT_SOVITS_SERVER_PORT,
            # GPT-SoVITS serves at the root path, not /v1/audio/speech.
            endpoint="/",
            use_remote_service=True,
            service_type=ServiceType.TTS,
        )
        self.megaservice.add(asr).add(llm).add(tts)
        self.megaservice.flow_to(asr, llm)
        self.megaservice.flow_to(llm, tts)
        # Fix: honor the host passed to __init__ instead of hard-coding
        # "0.0.0.0" — previously the constructor's host argument was ignored.
        self.gateway = AudioQnAGateway(megaservice=self.megaservice, host=self.host, port=self.port)
| 94 | + |
| 95 | + |
if __name__ == "__main__":
    # Build the mega-service and wire up the remote ASR/LLM/TTS backends.
    service = AudioQnAService(host=MEGA_SERVICE_HOST_IP, port=MEGA_SERVICE_PORT)
    service.add_remote_service()