3
3
import logging
4
4
import tempfile
5
5
from functools import partial
6
- from typing import Literal , Optional
6
+ from typing import Literal , Optional , List , Callable
7
7
8
+ import numpy as np
8
9
import torch
9
10
from omegaconf import OmegaConf
10
11
from vocos import Vocos
16
17
from .utils .infer import count_invalid_characters , detect_language , apply_character_map , apply_half2full_map , HomophonesReplacer
17
18
from .utils .io import get_latest_modified_file , del_all
18
19
from .infer .api import refine_text , infer_code
19
- from .utils .download import check_all_assets , download_all_assets
20
- from .utils .log import set_utils_logger
20
+ from .utils .dl import check_all_assets , download_all_assets
21
+ from .utils .log import logger as utils_logger
21
22
22
23
23
24
class Chat :
@@ -26,45 +27,45 @@ def __init__(self, logger=logging.getLogger(__name__)):
26
27
self .normalizer = {}
27
28
self .homophones_replacer = None
28
29
self .logger = logger
29
- set_utils_logger (logger )
30
+ utils_logger . set_logger (logger )
30
31
31
def has_loaded(self, use_decoder=False):
    """Return True if every model required for inference has been loaded.

    Verifies that ``gpt`` and ``tokenizer`` are present in
    ``self.pretrain_models``, plus ``decoder`` (when *use_decoder* is True)
    or ``dvae`` (otherwise), and that the vocos vocoder attributes
    (``self.vocos`` / ``self._vocos_decode``) have been set by ``_load``.

    Parameters
    ----------
    use_decoder : bool
        When True, require the ``decoder`` model instead of ``dvae``.

    Returns
    -------
    bool
        True when all required models are initialized.
    """
    not_finish = False
    check_list = ['gpt', 'tokenizer']

    if use_decoder:
        check_list.append('decoder')
    else:
        check_list.append('dvae')

    for module in check_list:
        if module not in self.pretrain_models:
            # Logger.warn is a deprecated alias of Logger.warning; use the
            # canonical name so no DeprecationWarning is emitted.
            self.logger.warning(f'{module} not initialized.')
            not_finish = True

    # vocos is stored as plain attributes (not in pretrain_models), so it is
    # checked separately via hasattr.
    if not hasattr(self, "_vocos_decode") or not hasattr(self, "vocos"):
        self.logger.warning('vocos not initialized.')
        not_finish = True

    if not not_finish:
        self.logger.info('all models has been initialized.')

    return not not_finish
49
54
50
- def load_models (
55
+ def download_models (
51
56
self ,
52
57
source : Literal ['huggingface' , 'local' , 'custom' ]= 'local' ,
53
58
force_redownload = False ,
54
- compile : bool = True ,
55
59
custom_path : Optional [torch .serialization .FILE_LIKE ]= None ,
56
- device : Optional [torch .device ] = None ,
57
- coef : Optional [torch .Tensor ] = None ,
58
- ):
60
+ ) -> Optional [str ]:
59
61
if source == 'local' :
60
- torch .load
61
62
download_path = os .getcwd ()
62
63
if not check_all_assets (update = True ) or force_redownload :
63
64
with tempfile .TemporaryDirectory () as tmp :
64
65
download_all_assets (tmpdir = tmp )
65
66
if not check_all_assets (update = False ):
66
- self .logger .error ("counld not satisfy all assets needed." )
67
- return False
67
+ self .logger .error ("download to local path %s failed." , download_path )
68
+ return None
68
69
elif source == 'huggingface' :
69
70
hf_home = os .getenv ('HF_HOME' , os .path .expanduser ("~/.cache/huggingface" ))
70
71
try :
@@ -73,18 +74,38 @@ def load_models(
73
74
download_path = None
74
75
if download_path is None or force_redownload :
75
76
self .logger .log (logging .INFO , f'Download from HF: https://huggingface.co/2Noise/ChatTTS' )
76
- download_path = snapshot_download (repo_id = "2Noise/ChatTTS" , allow_patterns = ["*.pt" , "*.yaml" ])
77
+ try :
78
+ download_path = snapshot_download (repo_id = "2Noise/ChatTTS" , allow_patterns = ["*.pt" , "*.yaml" ])
79
+ except :
80
+ download_path = None
77
81
else :
78
- self .logger .log (logging .INFO , f'Load from cache: { download_path } ' )
82
+ self .logger .log (logging .INFO , f'load latest snapshot from cache: { download_path } ' )
83
+ if download_path is None :
84
+ self .logger .error ("download from huggingface failed." )
85
+ return None
79
86
elif source == 'custom' :
80
- self .logger .log (logging .INFO , f'Load from local: { custom_path } ' )
87
+ self .logger .log (logging .INFO , f'try to load from local: { custom_path } ' )
81
88
download_path = custom_path
89
+
90
+ return download_path
82
91
92
def load_models(
    self,
    source: Literal['huggingface', 'local', 'custom'] = 'local',
    force_redownload=False,
    compile: bool = True,
    custom_path: Optional[torch.serialization.FILE_LIKE] = None,
    device: Optional[torch.device] = None,
    coef: Optional[torch.Tensor] = None,
) -> bool:
    """Resolve model assets via ``download_models`` and load them.

    Parameters
    ----------
    source : {'huggingface', 'local', 'custom'}
        Where to obtain the model assets from.
    force_redownload : bool
        Re-download assets even if a cached copy exists.
    compile : bool
        Forwarded to ``_load`` (model compilation toggle).
    custom_path : file-like or path, optional
        Asset location used when ``source == 'custom'``.
    device : torch.device, optional
        Device forwarded to ``_load``.
    coef : torch.Tensor, optional
        DVAE coefficient forwarded to ``_load``.

    Returns
    -------
    bool
        False when assets could not be resolved; otherwise the result of
        ``_load``.
    """
    download_path = self.download_models(source, force_redownload, custom_path)
    if download_path is None:
        return False
    # path.yaml maps model names to paths relative to the download root.
    path_cfg = OmegaConf.load(os.path.join(download_path, 'config', 'path.yaml'))
    resolved = {name: os.path.join(download_path, rel) for name, rel in path_cfg.items()}
    return self._load(device=device, compile=compile, coef=coef, **resolved)
87
-
108
+
88
109
def _load (
89
110
self ,
90
111
vocos_config_path : str = None ,
@@ -112,9 +133,17 @@ def _load(
112
133
).eval ()
113
134
assert vocos_ckpt_path , 'vocos_ckpt_path should not be None'
114
135
vocos .load_state_dict (torch .load (vocos_ckpt_path ))
115
- self .pretrain_models ['vocos' ] = vocos
136
+ self .vocos = vocos
137
+ if "mps" in str (self .device ):
138
+ self ._vocos_decode : Callable [[torch .Tensor ], np .ndarray ] = lambda spec : self .vocos .decode (
139
+ spec .cpu ()
140
+ ).cpu ().numpy ()
141
+ else :
142
+ self ._vocos_decode : Callable [[torch .Tensor ], np .ndarray ] = lambda spec : self .vocos .decode (
143
+ spec
144
+ ).cpu ().numpy ()
116
145
self .logger .log (logging .INFO , 'vocos loaded.' )
117
-
146
+
118
147
if dvae_config_path :
119
148
cfg = OmegaConf .load (dvae_config_path )
120
149
dvae = DVAE (** cfg , coef = coef ).to (device ).eval ()
@@ -157,8 +186,13 @@ def _load(
157
186
158
187
self .coef = coef
159
188
160
- return self .check_model ()
189
+ return self .has_loaded ()
161
190
191
def unload(self):
    """Release every loaded model and reset this instance to a fresh state.

    Keeps the configured logger alive across the reset, drops all heavy
    attributes with ``del_all``, then re-runs ``__init__`` so the object can
    be reused for a subsequent ``load_models`` call.
    """
    saved_logger = self.logger
    del_all(self)
    self.__init__(saved_logger)
162
196
def _infer (
163
197
self ,
164
198
text ,
@@ -173,23 +207,23 @@ def _infer(
173
207
do_homophone_replacement = True
174
208
):
175
209
176
- assert self .check_model (use_decoder = use_decoder )
210
+ assert self .has_loaded (use_decoder = use_decoder )
177
211
178
212
if not isinstance (text , list ):
179
213
text = [text ]
180
214
if do_text_normalization :
181
215
for i , t in enumerate (text ):
182
216
_lang = detect_language (t ) if lang is None else lang
183
- if self .init_normalizer (_lang ):
217
+ if self ._init_normalizer (_lang ):
184
218
text [i ] = self .normalizer [_lang ](t )
185
219
if _lang == 'zh' :
186
220
text [i ] = apply_half2full_map (text [i ])
187
221
for i , t in enumerate (text ):
188
222
invalid_characters = count_invalid_characters (t )
189
223
if len (invalid_characters ):
190
- self .logger .log ( logging . WARNING , f'Invalid characters found! : { invalid_characters } ' )
224
+ self .logger .warn ( f'Invalid characters found! : { invalid_characters } ' )
191
225
text [i ] = apply_character_map (t )
192
- if do_homophone_replacement and self .init_homophones_replacer ():
226
+ if do_homophone_replacement and self ._init_homophones_replacer ():
193
227
text [i ], replaced_words = self .homophones_replacer .replace (text [i ])
194
228
if replaced_words :
195
229
repl_res = ', ' .join ([f'{ _ [0 ]} ->{ _ [1 ]} ' for _ in replaced_words ])
@@ -205,64 +239,25 @@ def _infer(
205
239
text_tokens = refined .ids
206
240
text_tokens = [i [i < self .pretrain_models ['tokenizer' ].convert_tokens_to_ids ('[break_0]' )] for i in text_tokens ]
207
241
text = self .pretrain_models ['tokenizer' ].batch_decode (text_tokens )
208
- del_all ( refined )
242
+ refined . destroy ( )
209
243
if refine_text_only :
210
244
yield text
211
245
return
212
246
213
247
text = [params_infer_code .get ('prompt' , '' ) + i for i in text ]
214
248
params_infer_code .pop ('prompt' , '' )
215
- result_gen = infer_code (
249
+
250
+ length = [0 for _ in range (len (text ))]
251
+ for result in infer_code (
216
252
self .pretrain_models ,
217
253
text ,
218
254
device = self .device ,
219
255
** params_infer_code ,
220
256
return_hidden = use_decoder ,
221
257
stream = stream ,
222
- )
223
- if use_decoder :
224
- docoder_name = 'decoder'
225
- else :
226
- docoder_name = 'dvae'
227
- if "mps" in str (self .device ):
228
- vocos_decode = lambda spec : [self .pretrain_models ['vocos' ].decode (
229
- i .cpu ()
230
- ).cpu ().numpy () for i in spec ]
231
- else :
232
- vocos_decode = lambda spec : [self .pretrain_models ['vocos' ].decode (
233
- i
234
- ).cpu ().numpy () for i in spec ]
235
- if stream :
236
-
237
- length = 0
238
- for result in result_gen :
239
- x = result .hiddens if use_decoder else result .ids
240
- assert len (x ) == 1
241
- chunk_data = x [0 ]
242
- start_seek = length
243
- length = len (chunk_data )
244
- self .logger .debug (f'{ start_seek = } total len: { length } , new len: { length - start_seek = } ' )
245
- chunk_data = chunk_data [start_seek :]
246
- if not len (chunk_data ):
247
- continue
248
- self .logger .debug (f'new hidden { len (chunk_data )= } ' )
249
- mel_spec = [self .pretrain_models [docoder_name ](i [None ].permute (0 ,2 ,1 ).to (self .device )) for i in [chunk_data ]]
250
- del_all (result )
251
- del chunk_data
252
- del_all (x )
253
- wav = vocos_decode (mel_spec )
254
- del_all (mel_spec )
255
- self .logger .debug (f'yield wav chunk { len (wav [0 ])= } { len (wav [0 ][0 ])= } ' )
256
- yield wav
257
- return
258
- result = next (result_gen )
259
- x = result .hiddens if use_decoder else result .ids
260
- mel_spec = [self .pretrain_models [docoder_name ](i [None ].permute (0 ,2 ,1 ).to (self .device )) for i in x ]
261
- del_all (result )
262
- del_all (x )
263
- wav = vocos_decode (mel_spec )
264
- del_all (mel_spec )
265
- yield wav
258
+ ):
259
+ wav = self .decode_to_wavs (result , length , use_decoder )
260
+ yield wav
266
261
267
262
def infer (
268
263
self ,
@@ -294,13 +289,35 @@ def infer(
294
289
else :
295
290
return next (res_gen )
296
291
297
def sample_random_speaker(self):
    """Draw a random speaker embedding from the learned statistics.

    The embedding dimension is taken from the GPT MLP input size, and the
    sample is a Gaussian draw scaled/shifted by the stored per-dimension
    std/mean halves of ``spk_stat``.
    """
    gate_proj = self.pretrain_models['gpt'].gpt.layers[0].mlp.gate_proj
    dim = gate_proj.in_features
    # spk_stat packs [std | mean] along its only dimension.
    std, mean = self.pretrain_models['spk_stat'].chunk(2)
    noise = torch.randn(dim, device=std.device)
    return noise * std + mean
302
-
303
- def init_normalizer (self , lang ) -> bool :
296
+
297
def decode_to_wavs(self, result: GPT.GenerationOutputs, start_seeks: List[int], use_decoder: bool):
    """Decode freshly generated chunks into waveforms, one per batch item.

    ``start_seeks`` is updated in place so a later call decodes only the
    newly appended suffix of each sequence (streaming support). Items with
    no new data yield ``None`` in the returned list.
    """
    data = result.hiddens if use_decoder else result.ids
    # Same decoder for every item in the batch; look it up once.
    decoder = self.pretrain_models['decoder' if use_decoder else 'dvae']
    wavs: List[np.ndarray] = []
    for idx, seq in enumerate(data):
        seek = start_seeks[idx]
        total = len(seq)
        if total <= seek:
            # Nothing new was generated for this item since the last call.
            wavs.append(None)
            continue
        start_seeks[idx] = total
        fresh = seq[seek:]
        # (len, dim) -> (1, dim, len) as expected by the decoder/dvae.
        mel_spec = decoder(fresh[None].permute(0, 2, 1).to(self.device))
        del fresh
        wavs.append(self._vocos_decode(mel_spec))
        del_all(mel_spec)
    result.destroy()
    del_all(data)
    return wavs
319
+
320
+ def _init_normalizer (self , lang ) -> bool :
304
321
305
322
if lang in self .normalizer :
306
323
return True
@@ -335,16 +352,16 @@ def init_normalizer(self, lang) -> bool:
335
352
)
336
353
return False
337
354
338
- def init_homophones_replacer (self ):
355
+ def _init_homophones_replacer (self ):
339
356
if self .homophones_replacer :
340
357
return True
341
358
else :
342
359
try :
343
360
self .homophones_replacer = HomophonesReplacer (os .path .join (os .path .dirname (__file__ ), 'res' , 'homophones_map.json' ))
344
- self .logger .log (logging .INFO , 'homophones_replacer loaded.' )
361
+ self .logger .log (logging .INFO , 'successfully loaded HomophonesReplacer .' )
345
362
return True
346
363
except (IOError , json .JSONDecodeError ) as e :
347
- self .logger .log (logging .WARNING , f'Error loading homophones map: { e } ' )
364
+ self .logger .log (logging .WARNING , f'error loading homophones map: { e } ' )
348
365
except Exception as e :
349
- self .logger .log (logging .WARNING , f'Error loading homophones_replacer : { e } ' )
366
+ self .logger .log (logging .WARNING , f'error loading HomophonesReplacer : { e } ' )
350
367
return False
0 commit comments