Skip to content

Commit 7e75a50

Browse files
committed
delete skip_refine_text code.
update #
1 parent b7263a5 commit 7e75a50

File tree

3 files changed

+1477
-1285
lines changed

3 files changed

+1477
-1285
lines changed

examples/api/openai_api.ipynb

Lines changed: 1369 additions & 1118 deletions
Large diffs are not rendered by default.

examples/api/openai_api.py

Lines changed: 78 additions & 120 deletions
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,16 @@
11
"""
22
openai_api.py
3-
这个模块实现了一个基于 FastAPI 的语音合成 API,兼容 OpenAI 的接口规范。
43
This module implements a FastAPI-based text-to-speech API compatible with OpenAI's interface specification.
54
6-
主要功能和改进:
75
Main features and improvements:
8-
- 使用 app.state 管理全局状态,确保线程安全
9-
Use app.state to manage global state, ensuring thread safety
10-
- 添加异常处理和统一的错误响应,提升稳定性
11-
Add exception handling and unified error responses to improve stability
12-
- 支持多种语音选择和多种音频格式,增加灵活性
13-
Support multiple voice options and audio formats for greater flexibility
14-
- 增加输入验证,确保请求参数的合法性
15-
Add input validation to ensure the validity of request parameters
16-
- 支持更多 OpenAI TTS 参数(如 speed),提供更丰富的功能
17-
Support additional OpenAI TTS parameters (e.g., speed) for richer functionality
18-
- 实现了健康检查端点,便于监控服务状态
19-
Implement health check endpoint for easy service status monitoring
20-
- 使用异步锁(asyncio.Lock)管理模型访问,提升并发性能
21-
Use asyncio.Lock to manage model access, improving concurrency performance
22-
- 加载和管理说话者嵌入文件,支持个性化语音合成
23-
Load and manage speaker embedding files to support personalized speech synthesis
6+
- Use app.state to manage global state, ensuring thread safety
7+
- Add exception handling and unified error responses to improve stability
8+
- Support multiple voice options and audio formats for greater flexibility
9+
- Add input validation to ensure the validity of request parameters
10+
- Support additional OpenAI TTS parameters (e.g., speed) for richer functionality
11+
- Implement health check endpoint for easy service status monitoring
12+
- Use asyncio.Lock to manage model access, improving concurrency performance
13+
- Load and manage speaker embedding files to support personalized speech synthesis
2414
"""
2515
import io
2616
import os
@@ -33,104 +23,96 @@
3323
from pydantic import BaseModel, Field
3424
import torch
3525

36-
# 跨平台兼容性设置 / Cross-platform compatibility settings
26+
# Cross-platform compatibility settings
3727
if sys.platform == "darwin":
3828
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
3929

40-
# 设置工作目录并添加到系统路径 / Set working directory and add to system path
30+
# Set working directory and add to system path
4131
now_dir = os.getcwd()
4232
sys.path.append(now_dir)
4333

44-
# 导入必要的模块 / Import necessary modules
34+
# Import necessary modules
4535
import ChatTTS
4636
from tools.audio import pcm_arr_to_mp3_view, pcm_arr_to_ogg_view, pcm_arr_to_wav_view
4737
from tools.logger import get_logger
4838
from tools.normalizer.en import normalizer_en_nemo_text
4939
from tools.normalizer.zh import normalizer_zh_tn
5040

51-
# 初始化日志记录器 / Initialize logger
41+
# Initialize logger
5242
logger = get_logger("Command")
5343

54-
# 初始化 FastAPI 应用 / Initialize FastAPI application
44+
# Initialize FastAPI application
5545
app = FastAPI()
5646

57-
# 语音映射表 / Voice mapping table
58-
# 下载稳定音色 / Download stable voices:
59-
# 魔塔社区 https://modelscope.cn/studios/ttwwwaa/ChatTTS_Speaker
60-
# HuggingFace https://huggingface.co/spaces/taa/ChatTTS_Speaker
47+
# Voice mapping table
48+
# Download stable voices:
49+
# ModelScope Community: https://modelscope.cn/studios/ttwwwaa/ChatTTS_Speaker
50+
# HuggingFace: https://huggingface.co/spaces/taa/ChatTTS_Speaker
6151
VOICE_MAP = {
6252
"default": "1528.pt",
6353
"alloy": "1384.pt",
6454
"echo": "2443.pt",
65-
"speak_1000": "1000.pt",
66-
"speak_1509": "1509.pt",
67-
"speak_1996": "1996.pt",
68-
"speak_2115": "2115.pt",
69-
"speak_2166": "2166.pt",
70-
"speak_2218": "2218.pt"
7155
}
7256

73-
# 允许的音频格式 / Allowed audio formats
57+
# Allowed audio formats
7458
ALLOWED_FORMATS = {"mp3", "wav", "ogg"}
7559

7660
@app.on_event("startup")
7761
async def startup_event():
78-
"""应用启动时加载 ChatTTS 模型及默认说话者嵌入
79-
Load ChatTTS model and default speaker embedding when the application starts"""
80-
# 初始化 ChatTTS 和异步锁 / Initialize ChatTTS and async lock
62+
"""Load ChatTTS model and default speaker embedding when the application starts"""
63+
# Initialize ChatTTS and async lock
8164
app.state.chat = ChatTTS.Chat(get_logger("ChatTTS"))
82-
app.state.model_lock = asyncio.Lock() # 使用异步锁替代线程锁 / Use async lock instead of thread lock
65+
app.state.model_lock = asyncio.Lock() # Use async lock instead of thread lock
8366

84-
# 注册文本规范化器 / Register text normalizers
67+
# Register text normalizers
8568
app.state.chat.normalizer.register("en", normalizer_en_nemo_text())
8669
app.state.chat.normalizer.register("zh", normalizer_zh_tn())
8770

88-
logger.info("正在初始化 ChatTTS... / Initializing ChatTTS...")
71+
logger.info("Initializing ChatTTS...")
8972
if app.state.chat.load(source="huggingface"):
90-
logger.info("模型加载成功。 / Model loaded successfully.")
73+
logger.info("Model loaded successfully.")
9174
else:
92-
logger.error("模型加载失败,退出应用。 / Model loading failed, exiting application.")
75+
logger.error("Model loading failed, exiting application.")
9376
raise RuntimeError("Failed to load ChatTTS model")
9477

95-
# 加载默认说话者嵌入 / Load default speaker embedding
96-
# 在启动时预加载所有支持的说话者嵌入到内存中,避免运行时重复加载
78+
# Load default speaker embedding
79+
# Preload all supported speaker embeddings into memory at startup to avoid repeated loading during runtime
9780
app.state.spk_emb_map = {}
9881
for voice, spk_path in VOICE_MAP.items():
9982
if os.path.exists(spk_path):
10083
app.state.spk_emb_map[voice] = torch.load(spk_path, map_location=torch.device("cpu"))
101-
logger.info(f"预加载说话者嵌入: {voice} -> {spk_path}")
84+
logger.info(f"Preloading speaker embedding: {voice} -> {spk_path}")
10285
else:
103-
logger.warning(f"未找到 {spk_path},跳过预加载")
104-
app.state.spk_emb = app.state.spk_emb_map.get("default") # 默认嵌入
86+
logger.warning(f"Speaker embedding not found: {spk_path}, skipping preload")
87+
app.state.spk_emb = app.state.spk_emb_map.get("default") # Default embedding
10588

106-
# 请求参数白名单 / Request parameter whitelist
89+
# Request parameter whitelist
10790
ALLOWED_PARAMS = {"model", "input", "voice", "response_format", "speed", "stream", "output_format"}
10891

10992
class OpenAITTSRequest(BaseModel):
110-
"""OpenAI TTS 请求数据模型 / OpenAI TTS request data model"""
111-
model: str = Field(..., description="语音合成模型,固定为 'tts-1' / Speech synthesis model, fixed as 'tts-1'")
112-
input: str = Field(..., description="待合成的文本内容 / Text content to synthesize", max_length=2048) # 限制长度 / Length limit
113-
voice: Optional[str] = Field("default", description="语音选择,支持: default, alloy, echo / Voice selection, supports: default, alloy, echo")
114-
response_format: Optional[str] = Field("mp3", description="音频格式: mp3, wav, ogg / Audio format: mp3, wav, ogg")
115-
speed: Optional[float] = Field(1.0, ge=0.5, le=2.0, description="语速,范围 0.5-2.0 / Speed, range 0.5-2.0")
116-
stream: Optional[bool] = Field(False, description="是否流式传输 / Whether to stream")
117-
output_format: Optional[str] = "mp3" # 可选格式:mp3, wav, ogg / Optional formats: mp3, wav, ogg
118-
extra_params: Dict[str, Optional[str]] = Field(default_factory=dict, description="不支持的额外参数 / Unsupported extra parameters")
93+
"""OpenAI TTS request data model"""
94+
model: str = Field(..., description="Speech synthesis model, fixed as 'tts-1'")
95+
input: str = Field(..., description="Text content to synthesize", max_length=2048) # Length limit
96+
voice: Optional[str] = Field("default", description="Voice selection, supports: default, alloy, echo")
97+
response_format: Optional[str] = Field("mp3", description="Audio format: mp3, wav, ogg")
98+
speed: Optional[float] = Field(1.0, ge=0.5, le=2.0, description="Speed, range 0.5-2.0")
99+
stream: Optional[bool] = Field(False, description="Whether to stream")
100+
output_format: Optional[str] = "mp3" # Optional formats: mp3, wav, ogg
101+
extra_params: Dict[str, Optional[str]] = Field(default_factory=dict, description="Unsupported extra parameters")
119102

120103
@classmethod
121104
def validate_request(cls, request_data: Dict):
122-
"""过滤不支持的请求参数,并统一 model 值为 'tts-1'
123-
Filter unsupported request parameters and unify model value to 'tts-1'"""
124-
request_data["model"] = "tts-1" # 统一 model 值 / Unify model value
105+
"""Filter unsupported request parameters and unify model value to 'tts-1'"""
106+
request_data["model"] = "tts-1" # Unify model value
125107
unsupported_params = set(request_data.keys()) - ALLOWED_PARAMS
126108
if unsupported_params:
127-
logger.warning(f"忽略不支持的参数: {unsupported_params} / Ignoring unsupported parameters: {unsupported_params}")
109+
logger.warning(f"Ignoring unsupported parameters: {unsupported_params}")
128110
return {key: request_data[key] for key in ALLOWED_PARAMS if key in request_data}
129111

130-
# 统一错误响应 / Unified error response
112+
# Unified error response
131113
@app.exception_handler(Exception)
132114
async def custom_exception_handler(request, exc):
133-
"""自定义异常处理 / Custom exception handler"""
115+
"""Custom exception handler"""
134116
logger.error(f"Error: {str(exc)}")
135117
return JSONResponse(
136118
status_code=getattr(exc, "status_code", 500),
@@ -139,49 +121,36 @@ async def custom_exception_handler(request, exc):
139121

140122
@app.post("/v1/audio/speech")
141123
async def generate_voice(request_data: Dict):
142-
"""处理语音合成请求 / Handle speech synthesis request"""
124+
"""Handle speech synthesis request"""
143125
request_data = OpenAITTSRequest.validate_request(request_data)
144126
request = OpenAITTSRequest(**request_data)
145127

146-
logger.info(f"收到请求: text={request.input[:50]}..., voice={request.voice}, stream={request.stream} / Received request: text={request.input[:50]}..., voice={request.voice}, stream={request.stream}")
128+
logger.info(f"Received request: text={request.input}..., voice={request.voice}, stream={request.stream}")
147129

148-
# 验证音频格式
130+
# Validate audio format
149131
if request.response_format not in ALLOWED_FORMATS:
150-
raise HTTPException(400, detail=f"不支持的音频格式: {request.response_format},支持: {', '.join(ALLOWED_FORMATS)}")
132+
raise HTTPException(400, detail=f"Unsupported audio format: {request.response_format}, supported formats: {', '.join(ALLOWED_FORMATS)}")
151133

152-
# 加载指定语音的说话者嵌入 / Load speaker embedding for the specified voice
134+
# Load speaker embedding for the specified voice
153135
spk_emb = app.state.spk_emb_map.get(request.voice, app.state.spk_emb)
154136

155-
# 推理参数 / Inference parameters
137+
# Inference parameters
156138
params_infer_main = {
157139
"text": [request.input],
158140
"stream": request.stream,
159141
"lang": None,
160-
"skip_refine_text": False,
161-
"refine_text_only": True,
142+
"skip_refine_text": True, # Do not use text refinement
143+
"refine_text_only": False,
162144
"use_decoder": True,
163145
"audio_seed": 12345678,
164-
"text_seed": 87654321,
165-
"do_text_normalization": True,
166-
"do_homophone_replacement": False,
146+
# "text_seed": 87654321, # Random seed for text processing, used to control text refinement
147+
"do_text_normalization": True, # Perform text normalization
148+
"do_homophone_replacement": True, # Perform homophone replacement
167149
}
168-
169-
# 精炼文本参数 / Refine text parameters
170-
params_refine_text = app.state.chat.RefineTextParams(
171-
prompt="",
172-
top_P=0.7,
173-
top_K=20,
174-
temperature=0.7,
175-
repetition_penalty=1.0,
176-
max_new_token=384,
177-
min_new_token=0,
178-
show_tqdm=True,
179-
ensure_non_empty=True,
180-
manual_seed=None,
181-
)
182-
# 推理代码参数 / Inference code parameters
150+
151+
# Inference code parameters
183152
params_infer_code = app.state.chat.InferCodeParams(
184-
#prompt=f"[speed_{int(request.speed * 10)}]", # 转换为 ChatTTS 支持的格式
153+
#prompt=f"[speed_{int(request.speed * 10)}]", # Convert to format supported by ChatTTS
185154
prompt="[speed_5]",
186155
top_P=0.5,
187156
top_K=10,
@@ -202,36 +171,25 @@ async def generate_voice(request_data: Dict):
202171

203172
try:
204173
async with app.state.model_lock:
205-
start_time = time.time()
206-
# 第一步:单独精炼文本 / Step 1: Refine text separately
207-
refined_text = app.state.chat.infer(
208-
text=params_infer_main["text"],
209-
skip_refine_text=False,
210-
refine_text_only=True # 只精炼文本,不生成语音 / Only refine text, do not generate speech
211-
)
212-
logger.info(f"Refined text: {refined_text}")
213-
logger.info(f"Refined text time: {time.time() - start_time:.2f} 秒")
214-
215-
# 第二步:用精炼后的文本生成语音 / Step 2: Generate speech with refined text
216174
wavs = app.state.chat.infer(
217-
text=params_infer_main["text"],
218-
stream=params_infer_main["stream"],
219-
lang=params_infer_main["lang"],
220-
skip_refine_text=True,
221-
use_decoder=params_infer_main["use_decoder"],
222-
do_text_normalization=False,
223-
do_homophone_replacement=params_infer_main['do_homophone_replacement'],
224-
params_refine_text=params_refine_text,
175+
text = params_infer_main["text"],
176+
stream = params_infer_main["stream"],
177+
lang = params_infer_main["lang"],
178+
skip_refine_text = params_infer_main["skip_refine_text"],
179+
use_decoder = params_infer_main["use_decoder"],
180+
do_text_normalization = params_infer_main["do_text_normalization"],
181+
do_homophone_replacement = params_infer_main['do_homophone_replacement'],
182+
# params_refine_text = params_refine_text,
225183
params_infer_code=params_infer_code,
226184
)
227185
except Exception as e:
228-
raise HTTPException(500, detail=f"语音合成失败 / Speech synthesis failed: {str(e)}")
186+
raise HTTPException(500, detail=f"Speech synthesis failed: {str(e)}")
229187

230188
def generate_wav_header(sample_rate=24000, bits_per_sample=16, channels=1):
231-
"""生成 WAV 文件头部(不指定数据长度) / Generate WAV file header (without data length)"""
189+
"""Generate WAV file header (without data length)"""
232190
header = bytearray()
233191
header.extend(b"RIFF")
234-
header.extend(b"\xFF\xFF\xFF\xFF") # 文件大小未知 / File size unknown
192+
header.extend(b"\xFF\xFF\xFF\xFF") # File size unknown
235193
header.extend(b"WAVEfmt ")
236194
header.extend((16).to_bytes(4, "little")) # fmt chunk size
237195
header.extend((1).to_bytes(2, "little")) # PCM format
@@ -243,34 +201,34 @@ def generate_wav_header(sample_rate=24000, bits_per_sample=16, channels=1):
243201
header.extend((block_align).to_bytes(2, "little")) # Block align
244202
header.extend((bits_per_sample).to_bytes(2, "little")) # Bits per sample
245203
header.extend(b"data")
246-
header.extend(b"\xFF\xFF\xFF\xFF") # 数据长度未知 / Data size unknown
204+
header.extend(b"\xFF\xFF\xFF\xFF") # Data size unknown
247205
return bytes(header)
248206

249-
# 处理音频输出格式 / Handle audio output format
207+
# Handle audio output format
250208
def convert_audio(wav, format):
251-
"""转换音频格式 / Convert audio format"""
209+
"""Convert audio format"""
252210
if format == "mp3":
253211
return pcm_arr_to_mp3_view(wav)
254212
elif format == "wav":
255-
return pcm_arr_to_wav_view(wav, include_header=False) # 流式时不含头部 / No header in streaming
213+
return pcm_arr_to_wav_view(wav, include_header=False) # No header in streaming
256214
elif format == "ogg":
257215
return pcm_arr_to_ogg_view(wav)
258216
return pcm_arr_to_mp3_view(wav)
259217

260-
# 返回流式输出音频数据
218+
# Return streaming audio data
261219
if request.stream:
262220
first_chunk = True
263221
async def audio_stream():
264222
nonlocal first_chunk
265223
for wav in wavs:
266224
if request.response_format == "wav" and first_chunk:
267-
yield generate_wav_header() # 发送 WAV 头部 / Send WAV header
225+
yield generate_wav_header() # Send WAV header
268226
first_chunk = False
269227
yield convert_audio(wav, request.response_format)
270228
media_type = "audio/wav" if request.response_format == "wav" else "audio/mpeg"
271229
return StreamingResponse(audio_stream(), media_type=media_type)
272230

273-
# 直接返回音频文件 / Return audio file directly
231+
# Return audio file directly
274232
if request.response_format == 'wav':
275233
music_data = pcm_arr_to_wav_view(wavs[0])
276234
else:
@@ -282,5 +240,5 @@ async def audio_stream():
282240

283241
@app.get("/health")
284242
async def health_check():
285-
"""健康检查端点 / Health check endpoint"""
243+
"""Health check endpoint"""
286244
return {"status": "healthy", "model_loaded": bool(app.state.chat)}

0 commit comments

Comments
 (0)