1
1
2
2
import os
3
+ import json
3
4
import logging
4
5
from functools import partial
5
6
from omegaconf import OmegaConf
9
10
from .model .dvae import DVAE
10
11
from .model .gpt import GPT_warpper
11
12
from .utils .gpu_utils import select_device
12
- from .utils .infer_utils import count_invalid_characters , detect_language , apply_character_map , apply_half2full_map
13
+ from .utils .infer_utils import count_invalid_characters , detect_language , apply_character_map , apply_half2full_map , HomophonesReplacer
13
14
from .utils .io_utils import get_latest_modified_file
14
15
from .infer .api import refine_text , infer_code
15
16
@@ -22,6 +23,7 @@ class Chat:
22
23
def __init__ (self , ):
23
24
self .pretrain_models = {}
24
25
self .normalizer = {}
26
+ self .homophones_replacer = None
25
27
self .logger = logging .getLogger (__name__ )
26
28
27
29
def check_model (self , level = logging .INFO , use_decoder = False ):
@@ -136,6 +138,7 @@ def infer(
136
138
use_decoder = True ,
137
139
do_text_normalization = True ,
138
140
lang = None ,
141
+ do_homophone_replacement = True
139
142
):
140
143
141
144
assert self .check_model (use_decoder = use_decoder )
@@ -156,6 +159,10 @@ def infer(
156
159
if len (invalid_characters ):
157
160
self .logger .log (logging .WARNING , f'Invalid characters found! : { invalid_characters } ' )
158
161
text [i ] = apply_character_map (t )
162
+ if do_homophone_replacement and self .init_homophones_replacer ():
163
+ text [i ] = self .homophones_replacer .replace (t )
164
+ if t != text [i ]:
165
+ self .logger .log (logging .INFO , f'Homophones replace: { t } -> { text [i ]} ' )
159
166
160
167
if not skip_refine_text :
161
168
text_tokens = refine_text (self .pretrain_models , text , ** params_refine_text )['ids' ]
@@ -219,3 +226,17 @@ def init_normalizer(self, lang) -> bool:
219
226
'Run: conda install -c conda-forge pynini=2.1.5 && pip install nemo_text_processing' ,
220
227
)
221
228
return False
229
+
230
+ def init_homophones_replacer (self ):
231
+ if self .homophones_replacer :
232
+ return True
233
+ else :
234
+ try :
235
+ self .homophones_replacer = HomophonesReplacer (os .path .join (os .path .dirname (__file__ ), 'res' , 'homophones_map.json' ))
236
+ self .logger .log (logging .INFO , 'homophones_replacer loaded.' )
237
+ return True
238
+ except (IOError , json .JSONDecodeError ) as e :
239
+ self .logger .log (logging .WARNING , f'Error loading homophones map: { e } ' )
240
+ except Exception as e :
241
+ self .logger .log (logging .WARNING , f'Error loading homophones_replacer: { e } ' )
242
+ return False
0 commit comments