Commit 946263d

Merge branch 'dev' into TF-conversion

2 parents: 08ffcd0 + 4090ff2

6 files changed: +2088 −8 lines

.gitattributes

Lines changed: 1 addition & 0 deletions
@@ -1,2 +1,3 @@
 # ignore jupyter notebooks in the language bar on github
 **/*.ipynb linguist-vendored
+*.ipynb

examples/api/README.md

Lines changed: 5 additions & 0 deletions
@@ -14,6 +14,11 @@ pip install -r examples/api/requirements.txt
 fastapi dev examples/api/main.py --host 0.0.0.0 --port 8000
 ```
 
+## Run OpenAI-compatible API server
+
+```
+fastapi dev examples/api/openai_api.py --host 0.0.0.0 --port 8000
+```
 ## Generate audio using requests
 
 ```
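For reference, a minimal client sketch for the endpoint added above (assuming the server is running locally on port 8000; the JSON fields mirror the `OpenAITTSRequest` model defined in `examples/api/openai_api.py` below):

```
import requests

# Minimal sketch: POST to the OpenAI-compatible endpoint added in this commit.
# Assumes the server started via `fastapi dev examples/api/openai_api.py` is
# listening on localhost:8000.
resp = requests.post(
    "http://localhost:8000/v1/audio/speech",
    json={
        "model": "tts-1",          # the server pins this value to "tts-1"
        "input": "Hello from ChatTTS!",
        "voice": "alloy",          # supported: default, alloy, echo
        "response_format": "mp3",  # supported: mp3, wav, ogg
    },
)
resp.raise_for_status()
with open("output.mp3", "wb") as f:
    f.write(resp.content)
```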

examples/api/openai_api.py

Lines changed: 285 additions & 0 deletions
@@ -0,0 +1,285 @@
"""
openai_api.py

This module implements a FastAPI-based text-to-speech API compatible with OpenAI's interface specification.

Main features and improvements:
- Use app.state to manage global state, ensuring thread safety
- Add exception handling and unified error responses to improve stability
- Support multiple voice options and audio formats for greater flexibility
- Add input validation to ensure the validity of request parameters
- Support additional OpenAI TTS parameters (e.g., speed) for richer functionality
- Implement a health check endpoint for easy service status monitoring
- Use asyncio.Lock to manage model access, improving concurrency performance
- Load and manage speaker embedding files to support personalized speech synthesis
"""

import io
import os
import sys
import asyncio
import time
from typing import Optional, Dict

from fastapi import FastAPI, HTTPException
from fastapi.responses import StreamingResponse, JSONResponse
from pydantic import BaseModel, Field
import torch

# Cross-platform compatibility settings
if sys.platform == "darwin":
    os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"

# Set working directory and add to system path
now_dir = os.getcwd()
sys.path.append(now_dir)

# Import necessary modules
import ChatTTS
from tools.audio import pcm_arr_to_mp3_view, pcm_arr_to_ogg_view, pcm_arr_to_wav_view
from tools.logger import get_logger
from tools.normalizer.en import normalizer_en_nemo_text
from tools.normalizer.zh import normalizer_zh_tn

# Initialize logger
logger = get_logger("Command")

# Initialize FastAPI application
app = FastAPI()

# Voice mapping table
# Download stable voices:
# ModelScope Community: https://modelscope.cn/studios/ttwwwaa/ChatTTS_Speaker
# HuggingFace: https://huggingface.co/spaces/taa/ChatTTS_Speaker
VOICE_MAP = {
    "default": "1528.pt",
    "alloy": "1384.pt",
    "echo": "2443.pt",
}

# Allowed audio formats
ALLOWED_FORMATS = {"mp3", "wav", "ogg"}


@app.on_event("startup")
async def startup_event():
    """Load the ChatTTS model and default speaker embedding when the application starts"""
    # Initialize ChatTTS and the async lock
    app.state.chat = ChatTTS.Chat(get_logger("ChatTTS"))
    app.state.model_lock = asyncio.Lock()  # Use an async lock instead of a thread lock

    # Register text normalizers
    app.state.chat.normalizer.register("en", normalizer_en_nemo_text())
    app.state.chat.normalizer.register("zh", normalizer_zh_tn())

    logger.info("Initializing ChatTTS...")
    if app.state.chat.load(source="huggingface"):
        logger.info("Model loaded successfully.")
    else:
        logger.error("Model loading failed, exiting application.")
        raise RuntimeError("Failed to load ChatTTS model")

    # Preload all supported speaker embeddings into memory at startup
    # to avoid repeated loading at request time
    app.state.spk_emb_map = {}
    for voice, spk_path in VOICE_MAP.items():
        if os.path.exists(spk_path):
            app.state.spk_emb_map[voice] = torch.load(
                spk_path, map_location=torch.device("cpu")
            )
            logger.info(f"Preloading speaker embedding: {voice} -> {spk_path}")
        else:
            logger.warning(f"Speaker embedding not found: {spk_path}, skipping preload")
    app.state.spk_emb = app.state.spk_emb_map.get("default")  # Default embedding


# Request parameter whitelist
ALLOWED_PARAMS = {
    "model",
    "input",
    "voice",
    "response_format",
    "speed",
    "stream",
    "output_format",
}


class OpenAITTSRequest(BaseModel):
    """OpenAI TTS request data model"""

    model: str = Field(..., description="Speech synthesis model, fixed as 'tts-1'")
    input: str = Field(
        ..., description="Text content to synthesize", max_length=2048
    )  # Length limit
    voice: Optional[str] = Field(
        "default", description="Voice selection, supports: default, alloy, echo"
    )
    response_format: Optional[str] = Field(
        "mp3", description="Audio format: mp3, wav, ogg"
    )
    speed: Optional[float] = Field(
        1.0, ge=0.5, le=2.0, description="Speed, range 0.5-2.0"
    )
    stream: Optional[bool] = Field(False, description="Whether to stream")
    output_format: Optional[str] = "mp3"  # Optional formats: mp3, wav, ogg
    extra_params: Dict[str, Optional[str]] = Field(
        default_factory=dict, description="Unsupported extra parameters"
    )

    @classmethod
    def validate_request(cls, request_data: Dict):
        """Filter unsupported request parameters and pin the model value to 'tts-1'"""
        request_data["model"] = "tts-1"  # Unify model value
        unsupported_params = set(request_data.keys()) - ALLOWED_PARAMS
        if unsupported_params:
            logger.warning(f"Ignoring unsupported parameters: {unsupported_params}")
        return {key: request_data[key] for key in ALLOWED_PARAMS if key in request_data}


# Unified error response
@app.exception_handler(Exception)
async def custom_exception_handler(request, exc):
    """Custom exception handler"""
    logger.error(f"Error: {str(exc)}")
    return JSONResponse(
        status_code=getattr(exc, "status_code", 500),
        content={"error": {"message": str(exc), "type": exc.__class__.__name__}},
    )


@app.post("/v1/audio/speech")
async def generate_voice(request_data: Dict):
    """Handle a speech synthesis request"""
    request_data = OpenAITTSRequest.validate_request(request_data)
    request = OpenAITTSRequest(**request_data)

    logger.info(
        f"Received request: text={request.input}..., voice={request.voice}, stream={request.stream}"
    )

    # Validate audio format
    if request.response_format not in ALLOWED_FORMATS:
        raise HTTPException(
            400,
            detail=f"Unsupported audio format: {request.response_format}, supported formats: {', '.join(ALLOWED_FORMATS)}",
        )

    # Load the speaker embedding for the requested voice, falling back to the default
    spk_emb = app.state.spk_emb_map.get(request.voice, app.state.spk_emb)

    # Inference parameters
    params_infer_main = {
        "text": [request.input],
        "stream": request.stream,
        "lang": None,
        "skip_refine_text": True,  # Do not use text refinement
        "refine_text_only": False,
        "use_decoder": True,
        "audio_seed": 12345678,
        # "text_seed": 87654321,  # Random seed for text processing, used to control text refinement
        "do_text_normalization": True,  # Perform text normalization
        "do_homophone_replacement": True,  # Perform homophone replacement
    }

    # Inference code parameters
    params_infer_code = app.state.chat.InferCodeParams(
        # prompt=f"[speed_{int(request.speed * 10)}]",  # Convert to the format supported by ChatTTS
        prompt="[speed_5]",
        top_P=0.5,
        top_K=10,
        temperature=0.1,
        repetition_penalty=1.1,
        max_new_token=2048,
        min_new_token=0,
        show_tqdm=True,
        ensure_non_empty=True,
        manual_seed=42,
        spk_emb=spk_emb,
        spk_smp=None,
        txt_smp=None,
        stream_batch=24,
        stream_speed=12000,
        pass_first_n_batches=2,
    )

    try:
        async with app.state.model_lock:
            wavs = app.state.chat.infer(
                text=params_infer_main["text"],
                stream=params_infer_main["stream"],
                lang=params_infer_main["lang"],
                skip_refine_text=params_infer_main["skip_refine_text"],
                use_decoder=params_infer_main["use_decoder"],
                do_text_normalization=params_infer_main["do_text_normalization"],
                do_homophone_replacement=params_infer_main["do_homophone_replacement"],
                # params_refine_text=params_refine_text,
                params_infer_code=params_infer_code,
            )
    except Exception as e:
        raise HTTPException(500, detail=f"Speech synthesis failed: {str(e)}")

    def generate_wav_header(sample_rate=24000, bits_per_sample=16, channels=1):
        """Generate a WAV file header (without data length)"""
        header = bytearray()
        header.extend(b"RIFF")
        header.extend(b"\xff\xff\xff\xff")  # File size unknown
        header.extend(b"WAVEfmt ")
        header.extend((16).to_bytes(4, "little"))  # fmt chunk size
        header.extend((1).to_bytes(2, "little"))  # PCM format
        header.extend((channels).to_bytes(2, "little"))  # Channels
        header.extend((sample_rate).to_bytes(4, "little"))  # Sample rate
        byte_rate = sample_rate * channels * bits_per_sample // 8
        header.extend((byte_rate).to_bytes(4, "little"))  # Byte rate
        block_align = channels * bits_per_sample // 8
        header.extend((block_align).to_bytes(2, "little"))  # Block align
        header.extend((bits_per_sample).to_bytes(2, "little"))  # Bits per sample
        header.extend(b"data")
        header.extend(b"\xff\xff\xff\xff")  # Data size unknown
        return bytes(header)

    # Handle audio output format
    def convert_audio(wav, format):
        """Convert audio format"""
        if format == "mp3":
            return pcm_arr_to_mp3_view(wav)
        elif format == "wav":
            return pcm_arr_to_wav_view(
                wav, include_header=False
            )  # No header in streaming
        elif format == "ogg":
            return pcm_arr_to_ogg_view(wav)
        return pcm_arr_to_mp3_view(wav)

    # Return streaming audio data
    if request.stream:
        first_chunk = True

        async def audio_stream():
            nonlocal first_chunk
            for wav in wavs:
                if request.response_format == "wav" and first_chunk:
                    yield generate_wav_header()  # Send the WAV header first
                    first_chunk = False
                yield convert_audio(wav, request.response_format)

        media_type = "audio/wav" if request.response_format == "wav" else "audio/mpeg"
        return StreamingResponse(audio_stream(), media_type=media_type)

    # Return the audio file directly
    if request.response_format == "wav":
        music_data = pcm_arr_to_wav_view(wavs[0])
    else:
        music_data = convert_audio(wavs[0], request.response_format)

    return StreamingResponse(
        io.BytesIO(music_data),
        media_type="audio/mpeg",
        headers={
            "Content-Disposition": f"attachment; filename=output.{request.response_format}"
        },
    )


@app.get("/health")
async def health_check():
    """Health check endpoint"""
    return {"status": "healthy", "model_loaded": bool(app.state.chat)}

openai_api.ipynb

Lines changed: 1719 additions & 0 deletions
Large diffs are not rendered by default.

tools/audio/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -1,4 +1,4 @@
 from .av import load_audio
-from .pcm import pcm_arr_to_mp3_view
+from .pcm import pcm_arr_to_mp3_view, pcm_arr_to_ogg_view, pcm_arr_to_wav_view
 from .ffmpeg import has_ffmpeg_installed
 from .np import float_to_int16
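The two newly exported converters are what `openai_api.py` uses for per-format output. A rough usage sketch (assuming, as in the API code above, an int16 mono PCM numpy array at 24 kHz; the synthetic sine wave here is only a stand-in for ChatTTS output):

```
import numpy as np

from tools.audio import pcm_arr_to_mp3_view, pcm_arr_to_ogg_view, pcm_arr_to_wav_view

# Stand-in for a ChatTTS result: one second of a 440 Hz tone as int16 PCM.
sr = 24000
t = np.arange(sr) / sr
wav = (np.sin(2 * np.pi * 440 * t) * 32767).astype(np.int16)

mp3_bytes = pcm_arr_to_mp3_view(wav)
ogg_bytes = pcm_arr_to_ogg_view(wav)
wav_bytes = pcm_arr_to_wav_view(wav)  # include_header=False yields raw chunks for streaming

with open("tone.mp3", "wb") as f:
    f.write(mp3_bytes)
```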
