"""
openai_api.py
This module implements a FastAPI-based text-to-speech API compatible with OpenAI's interface specification.

Main features and improvements:
- Use app.state to manage global state, ensuring thread safety
- Add exception handling and unified error responses to improve stability
- Support multiple voice options and audio formats for greater flexibility
- Add input validation to ensure request parameters are valid
- Support additional OpenAI TTS parameters (e.g., speed) for richer functionality
- Implement a health check endpoint for easy service status monitoring
- Use asyncio.Lock to manage model access, improving concurrency
- Load and manage speaker embedding files to support personalized speech synthesis
"""

import io
import os
import sys
import asyncio
from typing import Optional, Dict

from fastapi import FastAPI, HTTPException
from fastapi.responses import StreamingResponse, JSONResponse
from pydantic import BaseModel, Field
import torch

# Cross-platform compatibility settings
if sys.platform == "darwin":
    os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"

# Set working directory and add to system path
now_dir = os.getcwd()
sys.path.append(now_dir)

# Import necessary modules
import ChatTTS
from tools.audio import pcm_arr_to_mp3_view, pcm_arr_to_ogg_view, pcm_arr_to_wav_view
from tools.logger import get_logger
from tools.normalizer.en import normalizer_en_nemo_text
from tools.normalizer.zh import normalizer_zh_tn

# Initialize logger
logger = get_logger("Command")

# Initialize FastAPI application
app = FastAPI()

# Voice mapping table
# Download stable voices:
# ModelScope Community: https://modelscope.cn/studios/ttwwwaa/ChatTTS_Speaker
# HuggingFace: https://huggingface.co/spaces/taa/ChatTTS_Speaker
VOICE_MAP = {
    "default": "1528.pt",
    "alloy": "1384.pt",
    "echo": "2443.pt",
}

# Allowed audio formats
ALLOWED_FORMATS = {"mp3", "wav", "ogg"}

@app.on_event("startup")
async def startup_event():
    """Load the ChatTTS model and speaker embeddings when the application starts"""
    # Initialize ChatTTS and async lock
    app.state.chat = ChatTTS.Chat(get_logger("ChatTTS"))
    app.state.model_lock = asyncio.Lock()  # Use an async lock instead of a thread lock

    # Register text normalizers
    app.state.chat.normalizer.register("en", normalizer_en_nemo_text())
    app.state.chat.normalizer.register("zh", normalizer_zh_tn())

    logger.info("Initializing ChatTTS...")
    if app.state.chat.load(source="huggingface"):
        logger.info("Model loaded successfully.")
    else:
        logger.error("Model loading failed, exiting application.")
        raise RuntimeError("Failed to load ChatTTS model")

    # Preload all supported speaker embeddings into memory at startup
    # to avoid reloading them on every request
    app.state.spk_emb_map = {}
    for voice, spk_path in VOICE_MAP.items():
        if os.path.exists(spk_path):
            app.state.spk_emb_map[voice] = torch.load(spk_path, map_location=torch.device("cpu"))
            logger.info(f"Preloaded speaker embedding: {voice} -> {spk_path}")
        else:
            logger.warning(f"Speaker embedding not found: {spk_path}, skipping preload")
    app.state.spk_emb = app.state.spk_emb_map.get("default")  # Default embedding

# Request parameter whitelist
ALLOWED_PARAMS = {"model", "input", "voice", "response_format", "speed", "stream", "output_format"}

class OpenAITTSRequest(BaseModel):
    """OpenAI TTS request data model"""

    model: str = Field(..., description="Speech synthesis model, fixed as 'tts-1'")
    input: str = Field(..., description="Text content to synthesize", max_length=2048)  # Length limit
    voice: Optional[str] = Field("default", description="Voice selection, supports: default, alloy, echo")
    response_format: Optional[str] = Field("mp3", description="Audio format: mp3, wav, ogg")
    speed: Optional[float] = Field(1.0, ge=0.5, le=2.0, description="Speed, range 0.5-2.0")
    stream: Optional[bool] = Field(False, description="Whether to stream")
    output_format: Optional[str] = "mp3"  # Optional formats: mp3, wav, ogg
    extra_params: Dict[str, Optional[str]] = Field(default_factory=dict, description="Unsupported extra parameters")

    @classmethod
    def validate_request(cls, request_data: Dict):
        """Filter out unsupported request parameters and pin the model value to 'tts-1'"""
        request_data["model"] = "tts-1"  # Unify model value
        unsupported_params = set(request_data.keys()) - ALLOWED_PARAMS
        if unsupported_params:
            logger.warning(f"Ignoring unsupported parameters: {unsupported_params}")
        return {key: request_data[key] for key in ALLOWED_PARAMS if key in request_data}
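
# Illustrative usage of validate_request (hypothetical values): extra keys are
# dropped and "model" is pinned to "tts-1", e.g.
#   OpenAITTSRequest.validate_request({"model": "tts-2", "input": "hi", "foo": "bar"})
# logs a warning about {'foo'} and returns {"model": "tts-1", "input": "hi"}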

# Unified error response
@app.exception_handler(Exception)
async def custom_exception_handler(request, exc):
    """Custom exception handler"""
    logger.error(f"Error: {str(exc)}")
    return JSONResponse(
        status_code=getattr(exc, "status_code", 500),
        content={"error": {"message": str(exc), "type": exc.__class__.__name__}},
    )

@app.post("/v1/audio/speech")
async def generate_voice(request_data: Dict):
    """Handle a speech synthesis request"""
    request_data = OpenAITTSRequest.validate_request(request_data)
    request = OpenAITTSRequest(**request_data)

    # Log a truncated preview of the input text
    logger.info(f"Received request: text={request.input[:50]}..., voice={request.voice}, stream={request.stream}")

    # Validate audio format
    if request.response_format not in ALLOWED_FORMATS:
        raise HTTPException(
            400,
            detail=f"Unsupported audio format: {request.response_format}, supported formats: {', '.join(ALLOWED_FORMATS)}",
        )

    # Load the speaker embedding for the requested voice, falling back to the default
    spk_emb = app.state.spk_emb_map.get(request.voice, app.state.spk_emb)

    # Inference parameters
    params_infer_main = {
        "text": [request.input],
        "stream": request.stream,
        "lang": None,
        "skip_refine_text": True,  # Do not use text refinement
        "refine_text_only": False,
        "use_decoder": True,
        "audio_seed": 12345678,
        # "text_seed": 87654321,  # Random seed for text processing, used to control text refinement
        "do_text_normalization": True,  # Perform text normalization
        "do_homophone_replacement": True,  # Perform homophone replacement
    }

    # Inference code parameters
    params_infer_code = app.state.chat.InferCodeParams(
        # prompt=f"[speed_{int(request.speed * 10)}]",  # Convert speed to the tag format ChatTTS expects
        prompt="[speed_5]",  # Fixed speed for now; the requested speed is currently ignored
        top_P=0.5,
        top_K=10,
        temperature=0.1,
        repetition_penalty=1.1,
        max_new_token=2048,
        min_new_token=0,
        show_tqdm=True,
        ensure_non_empty=True,
        manual_seed=42,
        spk_emb=spk_emb,
        spk_smp=None,
        txt_smp=None,
        stream_batch=24,
        stream_speed=12000,
        pass_first_n_batches=2,
    )

    try:
        async with app.state.model_lock:
            wavs = app.state.chat.infer(
                text=params_infer_main["text"],
                stream=params_infer_main["stream"],
                lang=params_infer_main["lang"],
                skip_refine_text=params_infer_main["skip_refine_text"],
                use_decoder=params_infer_main["use_decoder"],
                do_text_normalization=params_infer_main["do_text_normalization"],
                do_homophone_replacement=params_infer_main["do_homophone_replacement"],
                # params_refine_text=params_refine_text,
                params_infer_code=params_infer_code,
            )
    except Exception as e:
        raise HTTPException(500, detail=f"Speech synthesis failed: {str(e)}")

    def generate_wav_header(sample_rate=24000, bits_per_sample=16, channels=1):
        """Generate a WAV file header with unknown data length, for streaming"""
        header = bytearray()
        header.extend(b"RIFF")
        header.extend(b"\xff\xff\xff\xff")  # File size unknown
        header.extend(b"WAVEfmt ")
        header.extend((16).to_bytes(4, "little"))  # fmt chunk size
        header.extend((1).to_bytes(2, "little"))  # PCM format
        header.extend(channels.to_bytes(2, "little"))  # Channels
        header.extend(sample_rate.to_bytes(4, "little"))  # Sample rate
        byte_rate = sample_rate * channels * bits_per_sample // 8
        header.extend(byte_rate.to_bytes(4, "little"))  # Byte rate
        block_align = channels * bits_per_sample // 8
        header.extend(block_align.to_bytes(2, "little"))  # Block align
        header.extend(bits_per_sample.to_bytes(2, "little"))  # Bits per sample
        header.extend(b"data")
        header.extend(b"\xff\xff\xff\xff")  # Data size unknown
        return bytes(header)

    # Handle audio output format
    def convert_audio(wav, fmt):
        """Convert a PCM array to the requested audio format"""
        if fmt == "mp3":
            return pcm_arr_to_mp3_view(wav)
        elif fmt == "wav":
            return pcm_arr_to_wav_view(wav, include_header=False)  # No header in streaming
        elif fmt == "ogg":
            return pcm_arr_to_ogg_view(wav)
        return pcm_arr_to_mp3_view(wav)  # Fallback to mp3

    # Map each response format to its media type
    media_types = {"mp3": "audio/mpeg", "wav": "audio/wav", "ogg": "audio/ogg"}

    # Return streaming audio data
    if request.stream:
        first_chunk = True

        async def audio_stream():
            nonlocal first_chunk
            for wav in wavs:
                if request.response_format == "wav" and first_chunk:
                    yield generate_wav_header()  # Send the WAV header before the first chunk
                    first_chunk = False
                yield convert_audio(wav, request.response_format)

        return StreamingResponse(audio_stream(), media_type=media_types[request.response_format])

    # Return the audio file directly
    if request.response_format == "wav":
        music_data = pcm_arr_to_wav_view(wavs[0])
    else:
        music_data = convert_audio(wavs[0], request.response_format)

    return StreamingResponse(
        io.BytesIO(music_data),
        media_type=media_types[request.response_format],
        headers={"Content-Disposition": f"attachment; filename=output.{request.response_format}"},
    )

@app.get("/health")
async def health_check():
    """Health check endpoint"""
    return {"status": "healthy", "model_loaded": bool(app.state.chat)}