Skip to content

Commit ce1c962

Browse files
authored
fix: replace mispronounced words in TTS (#350)
* [Fix] Replace mispronounced words in TTS using hack method * [Refactor] Move homophones_map.json to res folder * [Refactor] Cache homophones_map.json and document source * [Refactor] document process in class docstring.
1 parent 0738ee6 commit ce1c962

File tree

3 files changed

+16488
-3
lines changed

3 files changed

+16488
-3
lines changed

ChatTTS/core.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11

22
import os
3+
import json
34
import logging
45
from functools import partial
56
from omegaconf import OmegaConf
@@ -9,7 +10,7 @@
910
from .model.dvae import DVAE
1011
from .model.gpt import GPT_warpper
1112
from .utils.gpu_utils import select_device
12-
from .utils.infer_utils import count_invalid_characters, detect_language, apply_character_map, apply_half2full_map
13+
from .utils.infer_utils import count_invalid_characters, detect_language, apply_character_map, apply_half2full_map, HomophonesReplacer
1314
from .utils.io_utils import get_latest_modified_file
1415
from .infer.api import refine_text, infer_code
1516

@@ -22,6 +23,7 @@ class Chat:
2223
def __init__(self, ):
2324
self.pretrain_models = {}
2425
self.normalizer = {}
26+
self.homophones_replacer = None
2527
self.logger = logging.getLogger(__name__)
2628

2729
def check_model(self, level = logging.INFO, use_decoder = False):
@@ -136,6 +138,7 @@ def infer(
136138
use_decoder=True,
137139
do_text_normalization=True,
138140
lang=None,
141+
do_homophone_replacement=True
139142
):
140143

141144
assert self.check_model(use_decoder=use_decoder)
@@ -156,6 +159,10 @@ def infer(
156159
if len(invalid_characters):
157160
self.logger.log(logging.WARNING, f'Invalid characters found! : {invalid_characters}')
158161
text[i] = apply_character_map(t)
162+
if do_homophone_replacement and self.init_homophones_replacer():
163+
text[i] = self.homophones_replacer.replace(t)
164+
if t != text[i]:
165+
self.logger.log(logging.INFO, f'Homophones replace: {t} -> {text[i]}')
159166

160167
if not skip_refine_text:
161168
text_tokens = refine_text(self.pretrain_models, text, **params_refine_text)['ids']
@@ -219,3 +226,17 @@ def init_normalizer(self, lang) -> bool:
219226
'Run: conda install -c conda-forge pynini=2.1.5 && pip install nemo_text_processing',
220227
)
221228
return False
229+
230+
def init_homophones_replacer(self):
231+
if self.homophones_replacer:
232+
return True
233+
else:
234+
try:
235+
self.homophones_replacer = HomophonesReplacer(os.path.join(os.path.dirname(__file__), 'res', 'homophones_map.json'))
236+
self.logger.log(logging.INFO, 'homophones_replacer loaded.')
237+
return True
238+
except (IOError, json.JSONDecodeError) as e:
239+
self.logger.log(logging.WARNING, f'Error loading homophones map: {e}')
240+
except Exception as e:
241+
self.logger.log(logging.WARNING, f'Error loading homophones_replacer: {e}')
242+
return False

0 commit comments

Comments
 (0)