Skip to content

Commit 2b45fde

Browse files
chore: update dependencies and enhance TTS Kokoro with all supported languages
feat: add support for Orpheus models Additional voices added to Orpheus and longer generation supported. pre-commit
1 parent 09c2c02 commit 2b45fde

File tree

8 files changed

+676
-247
lines changed

8 files changed

+676
-247
lines changed

mlx_audio/server.py

+97-5
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from fastapi.middleware.cors import CORSMiddleware
1313
from fastapi.responses import FileResponse, HTMLResponse, JSONResponse
1414
from fastapi.staticfiles import StaticFiles
15+
from huggingface_hub import list_repo_files
1516

1617

1718
# Configure logging
@@ -64,9 +65,10 @@ def tts_endpoint(
6465
voice: str = Form("af_heart"),
6566
speed: float = Form(1.0),
6667
model: str = Form("mlx-community/Kokoro-82M-4bit"),
68+
language: str = Form("american_english"),
6769
):
6870
"""
69-
POST an x-www-form-urlencoded form with 'text' (and optional 'voice', 'speed', and 'model').
71+
POST an x-www-form-urlencoded form with 'text' (and optional 'voice', 'speed', 'model', and 'language').
7072
We run TTS on the text, save the audio in a unique file,
7173
and return JSON with the filename so the client can retrieve it.
7274
"""
@@ -91,6 +93,10 @@ def tts_endpoint(
9193
"mlx-community/Kokoro-82M-6bit",
9294
"mlx-community/Kokoro-82M-8bit",
9395
"mlx-community/Kokoro-82M-bf16",
96+
"mlx-community/orpheus-3b-0.1-ft-bf16",
97+
"mlx-community/orpheus-3b-0.1-ft-8bit",
98+
"mlx-community/orpheus-3b-0.1-ft-6bit",
99+
"mlx-community/orpheus-3b-0.1-ft-4bit",
94100
]
95101
if model not in valid_models:
96102
return JSONResponse(
@@ -122,21 +128,42 @@ def tts_endpoint(
122128
output_path = os.path.join(OUTPUT_FOLDER, filename)
123129

124130
logger.debug(
125-
f"Generating TTS for text: '{text[:50]}...' with voice: {voice}, speed: {speed_float}, model: {model}"
131+
f"Generating TTS for text: '{text[:50]}...' with voice: {voice}, speed: {speed_float}, model: {model}, language: {language}"
126132
)
127133
logger.debug(f"Output file will be: {output_path}")
128134

135+
# Map language parameter to language code
136+
language_to_code = {
137+
"american_english": "a",
138+
"british_english": "b",
139+
"hindi": "h",
140+
"spanish": "s",
141+
"french": "f",
142+
"italian": "i",
143+
"brazilian_portuguese": "p",
144+
"japanese": "j",
145+
"mandarin_chinese": "c",
146+
}
147+
148+
# Set language code based on model type
149+
# For Orpheus models, always use "a" (American English)
150+
# For other models, use the language mapping
151+
if "orpheus" in model.lower():
152+
lang_code = "a" # Always use American English for Orpheus
153+
else:
154+
# Use language code from mapping, or fall back to first char of voice
155+
lang_code = language_to_code.get(language, voice[0])
156+
129157
# We'll use the high-level "model.generate" method:
130158
results = tts_model.generate(
131159
text=text,
132160
voice=voice,
133161
speed=speed_float,
134-
lang_code=voice[0],
162+
lang_code=lang_code,
135163
verbose=False,
136164
)
137165

138166
# We'll just gather all segments (if any) into a single wav
139-
# It's typical for multi-segment text to produce multiple wave segments:
140167
audio_arrays = []
141168
for segment in results:
142169
audio_arrays.append(segment.audio)
@@ -384,9 +411,63 @@ def open_output_folder():
384411
)
385412

386413

414+
def get_voice_names(repo_id):
415+
"""Fetches and returns a list of voice names (without extensions) from the given Hugging Face repository."""
416+
return [
417+
os.path.splitext(file.replace("voices/", ""))[0]
418+
for file in list_repo_files(repo_id)
419+
if file.startswith("voices/")
420+
]
421+
422+
423+
# Global variable to store the available voices
424+
available_voices = []
425+
426+
# List of supported models
427+
available_models = [
428+
{"id": "mlx-community/Kokoro-82M-4bit", "name": "Kokoro 82M 4bit"},
429+
{"id": "mlx-community/Kokoro-82M-6bit", "name": "Kokoro 82M 6bit"},
430+
{"id": "mlx-community/Kokoro-82M-8bit", "name": "Kokoro 82M 8bit"},
431+
{"id": "mlx-community/Kokoro-82M-bf16", "name": "Kokoro 82M bf16"},
432+
{"id": "mlx-community/orpheus-3b-0.1-ft-bf16", "name": "Orpheus 3B bf16"},
433+
{"id": "mlx-community/orpheus-3b-0.1-ft-8bit", "name": "Orpheus 3B 8bit"},
434+
{"id": "mlx-community/orpheus-3b-0.1-ft-6bit", "name": "Orpheus 3B 6bit"},
435+
{"id": "mlx-community/orpheus-3b-0.1-ft-4bit", "name": "Orpheus 3B 4bit"},
436+
]
437+
438+
439+
@app.get("/voices")
440+
def get_voices(repo_id: str = "hexgrad/Kokoro-82M", language: str = None):
441+
"""
442+
Return a list of available voice names.
443+
If language parameter is provided, filter voices starting with that language code.
444+
"""
445+
global available_voices
446+
447+
# For orpheus models, return a fixed list of voices
448+
if "orpheus" in repo_id.lower():
449+
voices = ["tara", "leah", "jess", "leo", "dan", "mia", "zac", "zoe"]
450+
return {"voices": voices}
451+
else:
452+
# Use the voices loaded during server startup
453+
voices = available_voices
454+
455+
# Filter voices by language code if provided
456+
if language:
457+
voices = [voice for voice in voices if voice.startswith(language)]
458+
459+
return {"voices": voices}
460+
461+
462+
@app.get("/models")
463+
def get_models():
464+
"""Return a list of available models."""
465+
return {"models": available_models}
466+
467+
387468
def setup_server():
388469
"""Setup the server by loading the model and creating the output directory."""
389-
global tts_model, audio_player, OUTPUT_FOLDER
470+
global tts_model, audio_player, OUTPUT_FOLDER, available_voices
390471

391472
# Make sure the output folder for generated TTS files exists
392473
try:
@@ -409,6 +490,17 @@ def setup_server():
409490
except Exception as fallback_error:
410491
logger.error(f"Error with fallback directory: {str(fallback_error)}")
411492

493+
# Load available voices
494+
try:
495+
default_repo = "hexgrad/Kokoro-82M"
496+
logger.debug(f"Loading voices from {default_repo}")
497+
available_voices = get_voice_names(default_repo)
498+
logger.debug(f"Successfully loaded {len(available_voices)} voices")
499+
except Exception as e:
500+
logger.error(f"Error loading voices: {str(e)}")
501+
logger.info("No voices loaded during startup")
502+
# We'll leave available_voices as an empty list
503+
412504
# Load the model if not already loaded
413505
if tts_model is None:
414506
try:

0 commit comments

Comments
 (0)