"""
openai_api.py

This module implements a FastAPI-based text-to-speech API compatible with
OpenAI's interface specification.

Main features and improvements:
- Use app.state to manage global state, ensuring thread safety
- Add exception handling and unified error responses to improve stability
- Support multiple voice options and audio formats for greater flexibility
- Add input validation to ensure the validity of request parameters
- Support additional OpenAI TTS parameters (e.g., speed) for richer functionality
- Implement health check endpoint for easy service status monitoring
- Use asyncio.Lock to manage model access, improving concurrency performance
- Load and manage speaker embedding files to support personalized speech synthesis
"""
import asyncio
import io
import os
import sys
from typing import Dict, Optional

from fastapi import FastAPI, HTTPException
from fastapi.responses import JSONResponse, StreamingResponse
from pydantic import BaseModel, Field
import torch

# NOTE(review): the import lines above `pydantic` were lost in the garbled
# source; the set here is reconstructed from visible usage (sys.platform,
# asyncio.Lock, Dict/Optional, FastAPI/HTTPException, JSONResponse,
# StreamingResponse) — confirm against upstream.

# Cross-platform compatibility settings: some ops are missing on the macOS
# MPS backend, so let PyTorch fall back to CPU for them.
if sys.platform == "darwin":
    os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"

# Set working directory and add to system path so the local `ChatTTS` and
# `tools` packages resolve when the script is run from the repo root.
now_dir = os.getcwd()
sys.path.append(now_dir)

# Import necessary modules (must come after the sys.path setup above).
import ChatTTS
from tools.audio import pcm_arr_to_mp3_view, pcm_arr_to_ogg_view, pcm_arr_to_wav_view
from tools.logger import get_logger
from tools.normalizer.en import normalizer_en_nemo_text
from tools.normalizer.zh import normalizer_zh_tn

# Initialize logger
logger = get_logger("Command")

# Initialize FastAPI application
app = FastAPI()
# Voice mapping table: maps an OpenAI-style voice name to a local speaker
# embedding file (.pt) loaded at startup.
# Download stable voices:
# ModelScope Community: https://modelscope.cn/studios/ttwwwaa/ChatTTS_Speaker
# HuggingFace: https://huggingface.co/spaces/taa/ChatTTS_Speaker
VOICE_MAP = {
    "default": "1528.pt",
    "alloy": "1384.pt",
    "echo": "2443.pt",
}

# Allowed audio formats for the `response_format` request field.
ALLOWED_FORMATS = {"mp3", "wav", "ogg"}
@app.on_event("startup")
async def startup_event():
    """Load the ChatTTS model and speaker embeddings when the application starts.

    Populates:
        app.state.chat        -- the ChatTTS model instance
        app.state.model_lock  -- asyncio.Lock serializing model access
        app.state.spk_emb_map -- preloaded speaker embeddings keyed by voice name
        app.state.spk_emb     -- default speaker embedding (None if "default"
                                 embedding file is missing)

    Raises:
        RuntimeError: if the ChatTTS model fails to load.
    """
    # Initialize ChatTTS and the async lock (an async lock instead of a
    # thread lock so waiting handlers don't block the event loop).
    app.state.chat = ChatTTS.Chat(get_logger("ChatTTS"))
    app.state.model_lock = asyncio.Lock()

    # Register text normalizers
    app.state.chat.normalizer.register("en", normalizer_en_nemo_text())
    app.state.chat.normalizer.register("zh", normalizer_zh_tn())

    logger.info("Initializing ChatTTS...")
    if app.state.chat.load(source="huggingface"):
        logger.info("Model loaded successfully.")
    else:
        logger.error("Model loading failed, exiting application.")
        raise RuntimeError("Failed to load ChatTTS model")

    # Preload all supported speaker embeddings into memory at startup to
    # avoid repeated disk loads at request time.
    app.state.spk_emb_map = {}
    for voice, spk_path in VOICE_MAP.items():
        if os.path.exists(spk_path):
            app.state.spk_emb_map[voice] = torch.load(spk_path, map_location=torch.device("cpu"))
            logger.info(f"Preloading speaker embedding: {voice} -> {spk_path}")
        else:
            logger.warning(f"Speaker embedding not found: {spk_path}, skipping preload")
    app.state.spk_emb = app.state.spk_emb_map.get("default")  # Default embedding
# Request parameter whitelist: keys outside this set are dropped (with a
# warning) by OpenAITTSRequest.validate_request.
ALLOWED_PARAMS = {"model", "input", "voice", "response_format", "speed", "stream", "output_format"}
class OpenAITTSRequest(BaseModel):
    """OpenAI TTS request data model (body of POST /v1/audio/speech)."""

    model: str = Field(..., description="Speech synthesis model, fixed as 'tts-1'")
    input: str = Field(..., description="Text content to synthesize", max_length=2048)  # Length limit
    voice: Optional[str] = Field("default", description="Voice selection, supports: default, alloy, echo")
    response_format: Optional[str] = Field("mp3", description="Audio format: mp3, wav, ogg")
    speed: Optional[float] = Field(1.0, ge=0.5, le=2.0, description="Speed, range 0.5-2.0")
    stream: Optional[bool] = Field(False, description="Whether to stream")
    output_format: Optional[str] = "mp3"  # Optional formats: mp3, wav, ogg
    extra_params: Dict[str, Optional[str]] = Field(default_factory=dict, description="Unsupported extra parameters")

    @classmethod
    def validate_request(cls, request_data: Dict) -> Dict:
        """Filter unsupported request parameters and force model to 'tts-1'.

        Args:
            request_data: raw request body as a dict (mutated in place: the
                "model" key is overwritten).

        Returns:
            A dict containing only whitelisted keys from ``request_data``.
        """
        request_data["model"] = "tts-1"  # Unify model value
        unsupported_params = set(request_data.keys()) - ALLOWED_PARAMS
        if unsupported_params:
            logger.warning(f"Ignoring unsupported parameters: {unsupported_params}")
        return {key: request_data[key] for key in ALLOWED_PARAMS if key in request_data}
# Unified error response
@app.exception_handler(Exception)
async def custom_exception_handler(request, exc):
    """Custom exception handler: log the error and return a JSON error body.

    Uses ``exc.status_code`` when present (e.g. HTTPException), else 500.
    """
    logger.error(f"Error: {str(exc)}")
    # NOTE(review): the exact response `content` was lost in the garbled
    # source; reconstructed as a simple error envelope — verify upstream.
    return JSONResponse(
        status_code=getattr(exc, "status_code", 500),
        content={"error": str(exc)},
    )
@app.post("/v1/audio/speech")
async def generate_voice(request_data: Dict):
    """Handle a speech synthesis request (OpenAI /v1/audio/speech compatible).

    Validates the request, runs ChatTTS inference under the model lock, and
    returns either a streaming response (request.stream=True) or the complete
    audio file in the requested format.

    Args:
        request_data: raw JSON body; filtered through
            OpenAITTSRequest.validate_request before model construction.

    Raises:
        HTTPException(400): unsupported audio format.
        HTTPException(500): any inference failure.
    """
    request_data = OpenAITTSRequest.validate_request(request_data)
    request = OpenAITTSRequest(**request_data)

    logger.info(f"Received request: text={request.input}..., voice={request.voice}, stream={request.stream}")

    # Validate audio format
    if request.response_format not in ALLOWED_FORMATS:
        raise HTTPException(400, detail=f"Unsupported audio format: {request.response_format}, supported formats: {', '.join(ALLOWED_FORMATS)}")

    # Load speaker embedding for the specified voice (fall back to default).
    spk_emb = app.state.spk_emb_map.get(request.voice, app.state.spk_emb)

    # Inference parameters
    params_infer_main = {
        "text": [request.input],
        "stream": request.stream,
        "lang": None,
        "skip_refine_text": True,  # Do not use text refinement
        "refine_text_only": False,
        "use_decoder": True,
        "audio_seed": 12345678,
        # "text_seed": 87654321,  # Random seed for text processing (refinement disabled)
        "do_text_normalization": True,   # Perform text normalization
        "do_homophone_replacement": True,  # Perform homophone replacement
    }

    # Inference code parameters
    # NOTE(review): fields after `top_K` were lost in the garbled source and
    # are reconstructed from typical ChatTTS usage — verify against upstream.
    params_infer_code = app.state.chat.InferCodeParams(
        # prompt=f"[speed_{int(request.speed * 10)}]",  # Convert to format supported by ChatTTS
        prompt="[speed_5]",
        top_P=0.5,
        top_K=10,
        temperature=0.3,
        repetition_penalty=1.05,
        spk_emb=spk_emb,
    )

    try:
        # Serialize model access; the async lock keeps the event loop free.
        async with app.state.model_lock:
            wavs = app.state.chat.infer(
                text=params_infer_main["text"],
                stream=params_infer_main["stream"],
                lang=params_infer_main["lang"],
                skip_refine_text=params_infer_main["skip_refine_text"],
                use_decoder=params_infer_main["use_decoder"],
                do_text_normalization=params_infer_main["do_text_normalization"],
                do_homophone_replacement=params_infer_main["do_homophone_replacement"],
                params_infer_code=params_infer_code,
            )
    except Exception as e:
        raise HTTPException(500, detail=f"Speech synthesis failed: {str(e)}")

    def generate_wav_header(sample_rate=24000, bits_per_sample=16, channels=1):
        """Generate a WAV file header without a known data length (0xFFFFFFFF
        placeholders), suitable for streaming."""
        header = bytearray()
        header.extend(b"RIFF")
        header.extend(b"\xFF\xFF\xFF\xFF")  # File size unknown
        header.extend(b"WAVEfmt ")
        header.extend((16).to_bytes(4, "little"))  # fmt chunk size
        header.extend((1).to_bytes(2, "little"))  # PCM format
        # NOTE(review): the channel/sample-rate/byte-rate fields were lost in
        # the garbled source; reconstructed per the canonical RIFF layout.
        header.extend((channels).to_bytes(2, "little"))  # Channels
        header.extend((sample_rate).to_bytes(4, "little"))  # Sample rate
        byte_rate = sample_rate * channels * bits_per_sample // 8
        header.extend((byte_rate).to_bytes(4, "little"))  # Byte rate
        block_align = channels * bits_per_sample // 8
        header.extend((block_align).to_bytes(2, "little"))  # Block align
        header.extend((bits_per_sample).to_bytes(2, "little"))  # Bits per sample
        header.extend(b"data")
        header.extend(b"\xFF\xFF\xFF\xFF")  # Data size unknown
        return bytes(header)

    # Handle audio output format
    def convert_audio(wav, format):
        """Convert a PCM array to the requested format (defaults to mp3)."""
        if format == "mp3":
            return pcm_arr_to_mp3_view(wav)
        elif format == "wav":
            return pcm_arr_to_wav_view(wav, include_header=False)  # No header in streaming
        elif format == "ogg":
            return pcm_arr_to_ogg_view(wav)
        return pcm_arr_to_mp3_view(wav)

    # Return streaming audio data
    if request.stream:
        first_chunk = True

        async def audio_stream():
            nonlocal first_chunk
            for wav in wavs:
                if request.response_format == "wav" and first_chunk:
                    yield generate_wav_header()  # Send WAV header
                    first_chunk = False
                yield convert_audio(wav, request.response_format)

        media_type = "audio/wav" if request.response_format == "wav" else "audio/mpeg"
        return StreamingResponse(audio_stream(), media_type=media_type)

    # Return audio file directly
    if request.response_format == 'wav':
        music_data = pcm_arr_to_wav_view(wavs[0])
    else:
        music_data = convert_audio(wavs[0], request.response_format)
    # NOTE(review): the final response statement was lost in the garbled
    # source; reconstructed — verify media type handling against upstream.
    return StreamingResponse(io.BytesIO(music_data), media_type=f"audio/{request.response_format}")
@app.get("/health")
async def health_check():
    """Health check endpoint: report service status and model availability."""
    return {"status": "healthy", "model_loaded": bool(app.state.chat)}
0 commit comments