Commit d582fd5

feat(example): add new API (#913)
Add a FastAPI-based, OpenAI-compatible text-to-speech API and audio format conversion tools.
1 parent 4c201cd commit d582fd5

File tree

6 files changed: +2042 additions, -8 deletions


.gitattributes (1 addition & 0 deletions)

```diff
@@ -1,2 +1,3 @@
 # ignore jupyter notebooks in the language bar on github
 **/*.ipynb linguist-vendored
+*.ipynb
```

examples/api/README.md (5 additions & 0 deletions)

````diff
@@ -14,6 +14,11 @@ pip install -r examples/api/requirements.txt
 fastapi dev examples/api/main.py --host 0.0.0.0 --port 8000
 ```
 
+## Run OpenAI API server
+
+```
+fastapi dev examples/api/openai_api.py --host 0.0.0.0 --port 8000
+```
 ## Generate audio using requests
 
 ```
````
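For reference, a minimal client sketch against the endpoint this commit adds (my example, not part of the commit; it assumes the server above is running on localhost:8000, and the JSON fields mirror the `OpenAITTSRequest` model in `examples/api/openai_api.py` below):

```python
# Minimal client sketch for the new /v1/audio/speech endpoint.
# Assumes: fastapi dev examples/api/openai_api.py --host 0.0.0.0 --port 8000
import requests

resp = requests.post(
    "http://localhost:8000/v1/audio/speech",
    json={
        "model": "tts-1",          # Normalized to "tts-1" server-side regardless
        "input": "Hello, ChatTTS!",
        "voice": "alloy",          # One of: default, alloy, echo
        "response_format": "mp3",  # One of: mp3, wav, ogg
    },
)
resp.raise_for_status()
with open("output.mp3", "wb") as f:
    f.write(resp.content)
```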

examples/api/openai_api.py (244 additions & 0 deletions, new file)

```python
"""
openai_api.py

This module implements a FastAPI-based text-to-speech API compatible with
OpenAI's interface specification.

Main features and improvements:
- Use app.state to manage global state, ensuring thread safety
- Add exception handling and unified error responses to improve stability
- Support multiple voice options and audio formats for greater flexibility
- Add input validation to ensure the validity of request parameters
- Support additional OpenAI TTS parameters (e.g., speed) for richer functionality
- Implement a health check endpoint for easy service status monitoring
- Use asyncio.Lock to manage model access, improving concurrency performance
- Load and manage speaker embedding files to support personalized speech synthesis
"""

import io
import os
import sys
import asyncio
import time
from typing import Optional, Dict

from fastapi import FastAPI, HTTPException
from fastapi.responses import StreamingResponse, JSONResponse
from pydantic import BaseModel, Field
import torch

# Cross-platform compatibility settings
if sys.platform == "darwin":
    os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"

# Set working directory and add to system path
now_dir = os.getcwd()
sys.path.append(now_dir)

# Import necessary modules
import ChatTTS
from tools.audio import pcm_arr_to_mp3_view, pcm_arr_to_ogg_view, pcm_arr_to_wav_view
from tools.logger import get_logger
from tools.normalizer.en import normalizer_en_nemo_text
from tools.normalizer.zh import normalizer_zh_tn

# Initialize logger
logger = get_logger("Command")

# Initialize FastAPI application
app = FastAPI()

# Voice mapping table
# Download stable voices:
# ModelScope Community: https://modelscope.cn/studios/ttwwwaa/ChatTTS_Speaker
# HuggingFace: https://huggingface.co/spaces/taa/ChatTTS_Speaker
VOICE_MAP = {
    "default": "1528.pt",
    "alloy": "1384.pt",
    "echo": "2443.pt",
}

# Allowed audio formats
ALLOWED_FORMATS = {"mp3", "wav", "ogg"}


@app.on_event("startup")
async def startup_event():
    """Load the ChatTTS model and speaker embeddings when the application starts."""
    # Initialize ChatTTS and the async lock
    app.state.chat = ChatTTS.Chat(get_logger("ChatTTS"))
    app.state.model_lock = asyncio.Lock()  # Use an async lock instead of a thread lock

    # Register text normalizers
    app.state.chat.normalizer.register("en", normalizer_en_nemo_text())
    app.state.chat.normalizer.register("zh", normalizer_zh_tn())

    logger.info("Initializing ChatTTS...")
    if app.state.chat.load(source="huggingface"):
        logger.info("Model loaded successfully.")
    else:
        logger.error("Model loading failed, exiting application.")
        raise RuntimeError("Failed to load ChatTTS model")

    # Preload all supported speaker embeddings into memory at startup
    # to avoid repeated loading at runtime.
    app.state.spk_emb_map = {}
    for voice, spk_path in VOICE_MAP.items():
        if os.path.exists(spk_path):
            app.state.spk_emb_map[voice] = torch.load(
                spk_path, map_location=torch.device("cpu")
            )
            logger.info(f"Preloading speaker embedding: {voice} -> {spk_path}")
        else:
            logger.warning(f"Speaker embedding not found: {spk_path}, skipping preload")
    app.state.spk_emb = app.state.spk_emb_map.get("default")  # Default embedding


# Request parameter whitelist
ALLOWED_PARAMS = {
    "model", "input", "voice", "response_format", "speed", "stream", "output_format",
}


class OpenAITTSRequest(BaseModel):
    """OpenAI TTS request data model"""

    model: str = Field(..., description="Speech synthesis model, fixed as 'tts-1'")
    input: str = Field(..., description="Text content to synthesize", max_length=2048)
    voice: Optional[str] = Field(
        "default", description="Voice selection, supports: default, alloy, echo"
    )
    response_format: Optional[str] = Field("mp3", description="Audio format: mp3, wav, ogg")
    speed: Optional[float] = Field(1.0, ge=0.5, le=2.0, description="Speed, range 0.5-2.0")
    stream: Optional[bool] = Field(False, description="Whether to stream")
    output_format: Optional[str] = "mp3"  # Optional formats: mp3, wav, ogg
    extra_params: Dict[str, Optional[str]] = Field(
        default_factory=dict, description="Unsupported extra parameters"
    )

    @classmethod
    def validate_request(cls, request_data: Dict):
        """Filter unsupported request parameters and unify the model value to 'tts-1'."""
        request_data["model"] = "tts-1"  # Unify model value
        unsupported_params = set(request_data.keys()) - ALLOWED_PARAMS
        if unsupported_params:
            logger.warning(f"Ignoring unsupported parameters: {unsupported_params}")
        return {key: request_data[key] for key in ALLOWED_PARAMS if key in request_data}


# Unified error response
@app.exception_handler(Exception)
async def custom_exception_handler(request, exc):
    """Custom exception handler"""
    logger.error(f"Error: {str(exc)}")
    return JSONResponse(
        status_code=getattr(exc, "status_code", 500),
        content={"error": {"message": str(exc), "type": exc.__class__.__name__}},
    )


@app.post("/v1/audio/speech")
async def generate_voice(request_data: Dict):
    """Handle a speech synthesis request."""
    request_data = OpenAITTSRequest.validate_request(request_data)
    request = OpenAITTSRequest(**request_data)

    logger.info(
        f"Received request: text={request.input}..., "
        f"voice={request.voice}, stream={request.stream}"
    )

    # Validate the audio format
    if request.response_format not in ALLOWED_FORMATS:
        raise HTTPException(
            400,
            detail=f"Unsupported audio format: {request.response_format}, "
            f"supported formats: {', '.join(ALLOWED_FORMATS)}",
        )

    # Load the speaker embedding for the specified voice
    spk_emb = app.state.spk_emb_map.get(request.voice, app.state.spk_emb)

    # Inference parameters
    params_infer_main = {
        "text": [request.input],
        "stream": request.stream,
        "lang": None,
        "skip_refine_text": True,  # Do not use text refinement
        "refine_text_only": False,
        "use_decoder": True,
        "audio_seed": 12345678,
        # "text_seed": 87654321,  # Random seed for text processing, controls text refinement
        "do_text_normalization": True,  # Perform text normalization
        "do_homophone_replacement": True,  # Perform homophone replacement
    }

    # Inference code parameters
    params_infer_code = app.state.chat.InferCodeParams(
        # prompt=f"[speed_{int(request.speed * 10)}]",  # Convert to the format supported by ChatTTS
        prompt="[speed_5]",
        top_P=0.5,
        top_K=10,
        temperature=0.1,
        repetition_penalty=1.1,
        max_new_token=2048,
        min_new_token=0,
        show_tqdm=True,
        ensure_non_empty=True,
        manual_seed=42,
        spk_emb=spk_emb,
        spk_smp=None,
        txt_smp=None,
        stream_batch=24,
        stream_speed=12000,
        pass_first_n_batches=2,
    )

    try:
        async with app.state.model_lock:
            wavs = app.state.chat.infer(
                text=params_infer_main["text"],
                stream=params_infer_main["stream"],
                lang=params_infer_main["lang"],
                skip_refine_text=params_infer_main["skip_refine_text"],
                use_decoder=params_infer_main["use_decoder"],
                do_text_normalization=params_infer_main["do_text_normalization"],
                do_homophone_replacement=params_infer_main["do_homophone_replacement"],
                # params_refine_text=params_refine_text,
                params_infer_code=params_infer_code,
            )
    except Exception as e:
        raise HTTPException(500, detail=f"Speech synthesis failed: {str(e)}")

    def generate_wav_header(sample_rate=24000, bits_per_sample=16, channels=1):
        """Generate a WAV file header (without data length)."""
        header = bytearray()
        header.extend(b"RIFF")
        header.extend(b"\xFF\xFF\xFF\xFF")  # File size unknown
        header.extend(b"WAVEfmt ")
        header.extend((16).to_bytes(4, "little"))  # fmt chunk size
        header.extend((1).to_bytes(2, "little"))  # PCM format
        header.extend(channels.to_bytes(2, "little"))  # Channels
        header.extend(sample_rate.to_bytes(4, "little"))  # Sample rate
        byte_rate = sample_rate * channels * bits_per_sample // 8
        header.extend(byte_rate.to_bytes(4, "little"))  # Byte rate
        block_align = channels * bits_per_sample // 8
        header.extend(block_align.to_bytes(2, "little"))  # Block align
        header.extend(bits_per_sample.to_bytes(2, "little"))  # Bits per sample
        header.extend(b"data")
        header.extend(b"\xFF\xFF\xFF\xFF")  # Data size unknown
        return bytes(header)

    # Handle the audio output format
    def convert_audio(wav, format):
        """Convert the audio format."""
        if format == "mp3":
            return pcm_arr_to_mp3_view(wav)
        elif format == "wav":
            return pcm_arr_to_wav_view(wav, include_header=False)  # No header when streaming
        elif format == "ogg":
            return pcm_arr_to_ogg_view(wav)
        return pcm_arr_to_mp3_view(wav)

    # Return streaming audio data
    if request.stream:
        first_chunk = True

        async def audio_stream():
            nonlocal first_chunk
            for wav in wavs:
                if request.response_format == "wav" and first_chunk:
                    yield generate_wav_header()  # Send the WAV header first
                    first_chunk = False
                yield convert_audio(wav, request.response_format)

        media_type = "audio/wav" if request.response_format == "wav" else "audio/mpeg"
        return StreamingResponse(audio_stream(), media_type=media_type)

    # Return the audio file directly
    if request.response_format == "wav":
        music_data = pcm_arr_to_wav_view(wavs[0])
    else:
        music_data = convert_audio(wavs[0], request.response_format)

    return StreamingResponse(
        io.BytesIO(music_data),
        media_type="audio/mpeg",
        headers={
            "Content-Disposition": f"attachment; filename=output.{request.response_format}"
        },
    )


@app.get("/health")
async def health_check():
    """Health check endpoint"""
    return {"status": "healthy", "model_loaded": bool(app.state.chat)}
```

openai_api.ipynb (1720 additions & 0 deletions)

Large diff not rendered.

tools/audio/__init__.py (1 addition & 1 deletion)

```diff
@@ -1,4 +1,4 @@
 from .av import load_audio
-from .pcm import pcm_arr_to_mp3_view
+from .pcm import pcm_arr_to_mp3_view, pcm_arr_to_ogg_view, pcm_arr_to_wav_view
 from .ffmpeg import has_ffmpeg_installed
 from .np import float_to_int16
```

tools/audio/pcm.py (71 additions & 7 deletions)

```diff
@@ -1,21 +1,85 @@
 import wave
 from io import BytesIO
-
 import numpy as np
-
 from .np import float_to_int16
 from .av import wav2
 
+def _pcm_to_wav_buffer(wav: np.ndarray, sample_rate: int = 24000) -> BytesIO:
+    """
+    Convert PCM audio data to a WAV format byte stream (internal utility function).
 
-def pcm_arr_to_mp3_view(wav: np.ndarray):
+    :param wav: PCM data, NumPy array, typically in float32 format.
+    :param sample_rate: Sample rate (in Hz), defaults to 24000.
+    :return: WAV format byte stream, stored in a BytesIO object.
+    """
+    # Create an in-memory byte stream buffer
     buf = BytesIO()
+
+    # Open a WAV file stream in write mode
     with wave.open(buf, "wb") as wf:
-        wf.setnchannels(1)  # Mono channel
-        wf.setsampwidth(2)  # Sample width in bytes
-        wf.setframerate(24000)  # Sample rate in Hz
+        # Set number of channels to 1 (mono)
+        wf.setnchannels(1)
+        # Set sample width to 2 bytes (16-bit)
+        wf.setsampwidth(2)
+        # Set sample rate
+        wf.setframerate(sample_rate)
+        # Convert PCM to 16-bit integers and write
         wf.writeframes(float_to_int16(wav))
+
+    # Reset buffer pointer to the beginning
     buf.seek(0, 0)
+    return buf
+
+def pcm_arr_to_mp3_view(wav: np.ndarray, sample_rate: int = 24000) -> memoryview:
+    """
+    Convert PCM audio data to MP3 format.
+
+    :param wav: PCM data, NumPy array, typically in float32 format.
+    :param sample_rate: Sample rate (in Hz), defaults to 24000.
+    :return: MP3 format byte data, returned as a memoryview.
+    """
+    # Get WAV format byte stream
+    buf = _pcm_to_wav_buffer(wav, sample_rate)
+
+    # Create output buffer
     buf2 = BytesIO()
+    # Convert WAV data to MP3
     wav2(buf, buf2, "mp3")
-    buf.seek(0, 0)
+    # Return MP3 data
     return buf2.getbuffer()
+
+def pcm_arr_to_ogg_view(wav: np.ndarray, sample_rate: int = 24000) -> memoryview:
+    """
+    Convert PCM audio data to OGG format (using Vorbis encoding).
+
+    :param wav: PCM data, NumPy array, typically in float32 format.
+    :param sample_rate: Sample rate (in Hz), defaults to 24000.
+    :return: OGG format byte data, returned as a memoryview.
+    """
+    # Get WAV format byte stream
+    buf = _pcm_to_wav_buffer(wav, sample_rate)
+
+    # Create output buffer
+    buf2 = BytesIO()
+    # Convert WAV data to OGG
+    wav2(buf, buf2, "ogg")
+    # Return OGG data
+    return buf2.getbuffer()
+
+def pcm_arr_to_wav_view(wav: np.ndarray, sample_rate: int = 24000, include_header: bool = True) -> memoryview:
+    """
+    Convert PCM audio data to WAV format, with an option to include the header.
+
+    :param wav: PCM data, NumPy array, typically in float32 format.
+    :param sample_rate: Sample rate (in Hz), defaults to 24000.
+    :param include_header: Whether to include the WAV header, defaults to True.
+    :return: WAV format or raw PCM byte data, returned as a memoryview.
+    """
+    if include_header:
+        # Get complete WAV byte stream
+        buf = _pcm_to_wav_buffer(wav, sample_rate)
+        return buf.getbuffer()
+    else:
+        # Return only converted 16-bit PCM data
+        pcm_data = float_to_int16(wav)
+        return memoryview(pcm_data.tobytes())
```
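A small usage sketch for the new helpers (my example, not part of the commit; it assumes the script runs from the repository root so `tools.audio` resolves, and that ffmpeg is available for the mp3/ogg paths via `wav2`):

```python
import numpy as np

from tools.audio import pcm_arr_to_mp3_view, pcm_arr_to_ogg_view, pcm_arr_to_wav_view

# One second of a 440 Hz sine tone as float32 PCM at the 24 kHz default rate.
sr = 24000
t = np.linspace(0, 1, sr, endpoint=False)
pcm = (0.5 * np.sin(2 * np.pi * 440.0 * t)).astype(np.float32)

with open("tone.mp3", "wb") as f:
    f.write(pcm_arr_to_mp3_view(pcm))
with open("tone.ogg", "wb") as f:
    f.write(pcm_arr_to_ogg_view(pcm))
with open("tone.wav", "wb") as f:
    f.write(pcm_arr_to_wav_view(pcm))  # Complete file with RIFF header

# Headerless 16-bit PCM, as used by the streaming WAV path in openai_api.py.
raw = pcm_arr_to_wav_view(pcm, include_header=False)
```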
