Skip to content

Commit

Permalink
feat(server): add input field for speaker_wav
Browse files Browse the repository at this point in the history
  • Loading branch information
eginhard committed Feb 21, 2025
1 parent cb3614a commit dd5c9fa
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 26 deletions.
24 changes: 4 additions & 20 deletions TTS/server/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@ def create_argparser() -> argparse.ArgumentParser:

# Feature flags read once from the loaded model's config and handed to the index template.
# TODO: set this from SpeakerManager
use_gst = api.synthesizer.tts_config.get("use_gst", False)  # False when the key is absent
# True only when the config's "model" entry is exactly "xtts" or "bark"
# (a missing entry falls back to "" and therefore yields False).
supports_cloning = api.synthesizer.tts_config.get("model", "") in ["xtts", "bark"]
# Flask application object that the route decorators in this module attach to.
app = Flask(__name__)


Expand All @@ -126,25 +127,6 @@ def style_wav_uri_to_dict(style_wav: str) -> str | dict:
return None


def speaker_wav_uri_to_dict(speaker_wav: str) -> str | dict:
"""Transform an uri speaker_wav, in either a string (path to wav file to be use for voice cloning)
or a dict (gst tokens/values to be use for voice cloning)
Args:
speaker_wav (str): uri
Returns:
Union[str, dict]: path to file (str) or gst speaker (dict)
"""
if speaker_wav:
if os.path.isfile(speaker_wav) and speaker_wav.endswith(".wav"):
return speaker_wav # local to the server

speaker_wav = json.loads(speaker_wav)
return speaker_wav
return None


@app.route("/")
def index():
return render_template(
Expand All @@ -155,6 +137,7 @@ def index():
speaker_ids=api.speakers,
language_ids=api.languages,
use_gst=use_gst,
supports_cloning=supports_cloning,
)


Expand Down Expand Up @@ -182,6 +165,8 @@ def tts():
speaker_idx = (
request.headers.get("speaker-id") or request.values.get("speaker_id", "") if api.is_multi_speaker else None
)
if speaker_idx == "":
speaker_idx = None
language_idx = (
request.headers.get("language-id") or request.values.get("language_id", "")
if api.is_multi_lingual
Expand All @@ -190,7 +175,6 @@ def tts():
style_wav = request.headers.get("style-wav") or request.values.get("style_wav", "")
style_wav = style_wav_uri_to_dict(style_wav)
speaker_wav = request.headers.get("speaker-wav") or request.values.get("speaker_wav", "")
speaker_wav = speaker_wav_uri_to_dict(speaker_wav)

logger.info("Model input: %s", text)
logger.info("Speaker idx: %s", speaker_idx)
Expand Down
19 changes: 14 additions & 5 deletions TTS/server/templates/index.html
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,12 @@

{%if use_gst%}
<input value='{"0": 0.1}' id="style_wav" placeholder="style wav (dict or path to wav).." size=45
type="text" name="style_wav">
type="text" name="style_wav"><br /><br />
{%endif%}

{%if supports_cloning%}
Reference audio:
<input id="speaker_wav" placeholder="path/to/speaker.wav" name="speaker_wav" accept=".wav"><br /><br />
{%endif%}

<input id="text" placeholder="Type here..." size=45 type="text" name="text">
Expand Down Expand Up @@ -114,14 +119,18 @@
q('#text').focus()
function do_tts(e) {
const text = q('#text').value
const speaker_id = getTextValue('#speaker_id')
const style_wav = getTextValue('#style_wav')
const speaker_wav = getTextValue('#speaker_wav')
let speaker_id = getTextValue('#speaker_id')
if (speaker_wav !== '') {
speaker_id = ''
}
const language_id = getTextValue('#language_id')
if (text) {
q('#message').textContent = 'Synthesizing...'
q('#speak-button').disabled = true
q('#audio').hidden = true
synthesize(text, speaker_id, style_wav, language_id)
synthesize(text, speaker_id, style_wav, speaker_wav, language_id)
}
e.preventDefault()
return false
Expand All @@ -132,8 +141,8 @@
do_tts(e)
}
})
function synthesize(text, speaker_id = "", style_wav = "", language_id = "") {
fetch(`/api/tts?text=${encodeURIComponent(text)}&speaker_id=${encodeURIComponent(speaker_id)}&style_wav=${encodeURIComponent(style_wav)}&language_id=${encodeURIComponent(language_id)}`, { cache: 'no-cache' })
function synthesize(text, speaker_id = "", style_wav = "", speaker_wav = "", language_id = "") {
fetch(`/api/tts?text=${encodeURIComponent(text)}&speaker_id=${encodeURIComponent(speaker_id)}&style_wav=${encodeURIComponent(style_wav)}&speaker_wav=${encodeURIComponent(speaker_wav)}&language_id=${encodeURIComponent(language_id)}`, { cache: 'no-cache' })
.then(function (res) {
if (!res.ok) throw Error(res.statusText)
return res.blob()
Expand Down
2 changes: 1 addition & 1 deletion TTS/tts/models/xtts.py
Original file line number Diff line number Diff line change
Expand Up @@ -391,7 +391,7 @@ def synthesize(self, text, config, speaker_wav, language, speaker_id=None, **kwa
"top_p": config.top_p,
}
settings.update(kwargs) # allow overriding of preset settings with kwargs
if speaker_id is not None and speaker_id != "":
if speaker_id is not None:
gpt_cond_latent, speaker_embedding = self.speaker_manager.speakers[speaker_id].values()
return self.inference(text, language, gpt_cond_latent, speaker_embedding, **settings)
settings.update(
Expand Down

0 comments on commit dd5c9fa

Please sign in to comment.