Skip to content

Commit 3fa6c37

Browse files
committed
Improved voice assignment. 1) Assign other voices of the same gender, if available, instead of reusing the same one. 2) Check that the regional selection does not filter all voices out.
1 parent 28806b2 commit 3fa6c37

File tree

3 files changed

+182
-23
lines changed

3 files changed

+182
-23
lines changed

open_dubbing/main.py

+21-4
Original file line numberDiff line numberDiff line change
@@ -67,15 +67,17 @@ def log_error_and_exit(msg: str, code: ExitCode):
6767
exit(code)
6868

6969

70-
def check_languages(source_language, target_language, _tts, translation, _stt):
70+
def check_languages(
71+
source_language, target_language, _tts, translation, _stt, target_language_region
72+
):
7173
spt = _stt.get_languages()
7274
translation_languages = translation.get_language_pairs()
7375
logger().debug(f"check_languages. Pairs {len(translation_languages)}")
7476

7577
tts = _tts.get_languages()
7678

7779
if source_language not in spt:
78-
msg = f"source language '{source_language}' is not supported by the speech recognition system. Supported languages: '{spt}"
80+
msg = f"source language '{source_language}' is not supported by the speech recognition system. Supported languages: '{spt}'"
7981
log_error_and_exit(msg, ExitCode.INVALID_LANGUAGE_SPT)
8082

8183
pair = (source_language, target_language)
@@ -84,7 +86,15 @@ def check_languages(source_language, target_language, _tts, translation, _stt):
8486
log_error_and_exit(msg, ExitCode.INVALID_LANGUAGE_TRANS)
8587

8688
if target_language not in tts:
87-
msg = f"target language '{target_language}' is not supported by the text to speech system. Supported languages: '{tts}"
89+
msg = f"target language '{target_language}' is not supported by the text to speech system. Supported languages: '{tts}'"
90+
log_error_and_exit(msg, ExitCode.INVALID_LANGUAGE_TTS)
91+
92+
voices = _tts.get_available_voices(language_code=target_language)
93+
region_voices = _tts.get_voices_for_region_only(
94+
voices=voices, target_language_region=target_language_region
95+
)
96+
if len(region_voices) == 0:
97+
msg = f"filtering by '{target_language_region}' returns no voices for language '{target_language}' in the text to speech system"
8898
log_error_and_exit(msg, ExitCode.INVALID_LANGUAGE_TTS)
8999

90100

@@ -275,7 +285,14 @@ def main():
275285
args.translator, args.nllb_model, args.apertium_server, args.device
276286
)
277287

278-
check_languages(source_language, args.target_language, tts, translation, stt)
288+
check_languages(
289+
source_language,
290+
args.target_language,
291+
tts,
292+
translation,
293+
stt,
294+
args.target_language_region,
295+
)
279296

280297
if not os.path.exists(args.output_directory):
281298
os.makedirs(args.output_directory)

open_dubbing/text_to_speech.py

+24-7
Original file line numberDiff line numberDiff line change
@@ -42,18 +42,17 @@ def __init__(self):
4242
def get_available_voices(self, language_code: str) -> List[Voice]:
4343
pass
4444

45-
def get_voices_with_region_preference(
45+
def get_voices_for_region_only(
4646
self, *, voices: List[Voice], target_language_region: str
4747
) -> List[Voice]:
4848
if len(target_language_region) == 0:
4949
return voices
5050

51-
voices_copy = voices[:]
51+
voices_copy = []
5252

5353
for voice in voices:
5454
if voice.region.endswith(target_language_region):
55-
voices_copy.remove(voice)
56-
voices_copy.insert(0, voice)
55+
voices_copy.append(voice)
5756

5857
return voices_copy
5958

@@ -66,21 +65,39 @@ def assign_voices(
6665
) -> Mapping[str, str | None]:
6766

6867
voices = self.get_available_voices(target_language)
69-
voices = self.get_voices_with_region_preference(
68+
region_voices = self.get_voices_for_region_only(
7069
voices=voices, target_language_region=target_language_region
7170
)
7271

7372
voice_assignment = {}
73+
used_voices = set()
7474
for chunk in utterance_metadata:
7575
speaker_id = chunk["speaker_id"]
7676
if speaker_id in voice_assignment:
7777
continue
7878

7979
gender = chunk["gender"]
80-
for voice in voices:
81-
if voice.gender.lower() == gender.lower():
80+
for voice in region_voices: # Try to use an unused voice of the same gender
81+
if (
82+
voice.name not in used_voices
83+
and voice.gender.lower() == gender.lower()
84+
):
8285
voice_assignment[speaker_id] = voice.name
86+
used_voices.add(voice.name)
8387
break
88+
else:
89+
for (
90+
voice
91+
) in region_voices: # Try to use an already used voice of same gender
92+
if voice.gender.lower() == gender.lower():
93+
voice_assignment[speaker_id] = voice.name
94+
used_voices.add(voice.name)
95+
break
96+
else: # Try to use any other voice of any gender even if it used
97+
for voice in region_voices:
98+
voice_assignment[speaker_id] = voice.name
99+
used_voices.add(voice.name)
100+
break
84101

85102
logger().info(f"text_to_speech.assign_voices. Returns: {voice_assignment}")
86103
return voice_assignment

tests/text_to_speech_test.py

+137-12
Original file line numberDiff line numberDiff line change
@@ -353,36 +353,46 @@ def test_get_start_time_of_next_speech_utterance(
353353
)
354354
assert result == expected_result
355355

356-
def test_get_voices_with_region_filter(self):
356+
def test_get_voices_for_region_only(self):
357357
voices = [
358358
Voice(name="Voice1", gender="Male", region="US"),
359359
Voice(name="Voice2", gender="Female", region="UK"),
360360
Voice(name="Voice3", gender="Male", region="IN"),
361361
Voice(name="Voice4", gender="Female", region="IN"),
362362
]
363363

364-
result = TextToSpeechUT().get_voices_with_region_preference(
364+
result = TextToSpeechUT().get_voices_for_region_only(
365365
voices=voices, target_language_region="UK"
366366
)
367-
assert result[0].region == "UK"
367+
assert 1 == len(result)
368+
assert "UK" == result[0].region
368369

369-
result = TextToSpeechUT().get_voices_with_region_preference(
370+
result = TextToSpeechUT().get_voices_for_region_only(
370371
voices=voices, target_language_region="IN"
371372
)
372-
assert result[0].region == "IN"
373-
assert result[1].region == "IN"
374373

375-
result = TextToSpeechUT().get_voices_with_region_preference(
374+
assert 2 == len(result)
375+
assert "IN" == result[0].region
376+
assert "IN" == result[1].region
377+
378+
result = TextToSpeechUT().get_voices_for_region_only(
376379
voices=voices, target_language_region=""
377380
)
378-
assert result[0].region == "US"
381+
assert 4 == len(result)
382+
assert "US" == result[0].region
379383

380-
def test_assign_voices(self):
384+
@pytest.mark.parametrize(
385+
"target_language_region, expected_voices",
386+
[
387+
("IN", {1: "Voice3"}),
388+
("", {1: "Voice1"}),
389+
],
390+
)
391+
def test_assign_voices_single_male(self, target_language_region, expected_voices):
381392
tts = TextToSpeechUT()
382393

383394
utterance_metadata = [
384395
{
385-
"assigned_voice": "en_voice",
386396
"speaker_id": 1,
387397
"gender": "Male",
388398
}
@@ -401,9 +411,124 @@ def test_assign_voices(self):
401411
results = tts.assign_voices(
402412
utterance_metadata=utterance_metadata,
403413
target_language="",
404-
target_language_region="IN",
414+
target_language_region=target_language_region,
415+
)
416+
assert expected_voices == results
417+
418+
@pytest.mark.parametrize(
419+
"target_language_region, expected_voices",
420+
[
421+
("IN", {1: "Voice2"}),
422+
("", {1: "Voice1"}),
423+
],
424+
)
425+
def test_assign_voices_single_male_no_male_voice(
426+
self, target_language_region, expected_voices
427+
):
428+
tts = TextToSpeechUT()
429+
430+
utterance_metadata = [
431+
{
432+
"speaker_id": 1,
433+
"gender": "Male",
434+
}
435+
]
436+
437+
voices = [
438+
Voice(name="Voice1", gender="Female", region="UK"),
439+
Voice(name="Voice2", gender="Female", region="IN"),
440+
]
441+
442+
tts = TextToSpeechUT()
443+
444+
with patch.object(tts, "get_available_voices", return_value=voices):
445+
results = tts.assign_voices(
446+
utterance_metadata=utterance_metadata,
447+
target_language="",
448+
target_language_region=target_language_region,
449+
)
450+
assert expected_voices == results
451+
452+
@pytest.mark.parametrize(
453+
"target_language_region, expected_voices",
454+
[
455+
("IN", {1: "Voice3", 2: "Voice3"}),
456+
("", {1: "Voice1", 2: "Voice3"}),
457+
],
458+
)
459+
def test_assign_voices_single_two_males_single_voice(
460+
self, target_language_region, expected_voices
461+
):
462+
tts = TextToSpeechUT()
463+
464+
utterance_metadata = [
465+
{
466+
"speaker_id": 1,
467+
"gender": "Male",
468+
},
469+
{
470+
"speaker_id": 2,
471+
"gender": "Male",
472+
},
473+
]
474+
475+
voices = [
476+
Voice(name="Voice1", gender="Male", region="US"),
477+
Voice(name="Voice2", gender="Female", region="US"),
478+
Voice(name="Voice3", gender="Male", region="IN"),
479+
Voice(name="Voice4", gender="Female", region="IN"),
480+
]
481+
482+
tts = TextToSpeechUT()
483+
484+
with patch.object(tts, "get_available_voices", return_value=voices):
485+
results = tts.assign_voices(
486+
utterance_metadata=utterance_metadata,
487+
target_language="",
488+
target_language_region=target_language_region,
489+
)
490+
assert expected_voices == results
491+
492+
@pytest.mark.parametrize(
493+
"target_language_region, expected_voices",
494+
[
495+
("IN", {1: "Voice3", 2: "Voice5"}),
496+
("", {1: "Voice1", 2: "Voice3"}),
497+
],
498+
)
499+
def test_assign_voices_single_two_males_two_voices(
500+
self, target_language_region, expected_voices
501+
):
502+
tts = TextToSpeechUT()
503+
504+
utterance_metadata = [
505+
{
506+
"speaker_id": 1,
507+
"gender": "Male",
508+
},
509+
{
510+
"speaker_id": 2,
511+
"gender": "Male",
512+
},
513+
]
514+
515+
voices = [
516+
Voice(name="Voice1", gender="Male", region="US"),
517+
Voice(name="Voice2", gender="Female", region="UK"),
518+
Voice(name="Voice3", gender="Male", region="IN"),
519+
Voice(name="Voice4", gender="Female", region="IN"),
520+
Voice(name="Voice5", gender="Male", region="IN"),
521+
]
522+
523+
tts = TextToSpeechUT()
524+
525+
with patch.object(tts, "get_available_voices", return_value=voices):
526+
results = tts.assign_voices(
527+
utterance_metadata=utterance_metadata,
528+
target_language="",
529+
target_language_region=target_language_region,
405530
)
406-
assert {1: "Voice3"} == results
531+
assert expected_voices == results
407532

408533
def _get_update_utterance_metadata(self):
409534
return [

0 commit comments

Comments
 (0)