Commit 46b007e

Merge pull request #415 from 2noise/optimzie
optimize: all
2 parents 2eb97d2 + 51c2118 commit 46b007e

16 files changed: +411, -268 lines changed

ChatTTS/core.py

Lines changed: 100 additions & 83 deletions
@@ -3,8 +3,9 @@
 import logging
 import tempfile
 from functools import partial
-from typing import Literal, Optional
+from typing import Literal, Optional, List, Callable
 
+import numpy as np
 import torch
 from omegaconf import OmegaConf
 from vocos import Vocos
@@ -16,8 +17,8 @@
 from .utils.infer import count_invalid_characters, detect_language, apply_character_map, apply_half2full_map, HomophonesReplacer
 from .utils.io import get_latest_modified_file, del_all
 from .infer.api import refine_text, infer_code
-from .utils.download import check_all_assets, download_all_assets
-from .utils.log import set_utils_logger
+from .utils.dl import check_all_assets, download_all_assets
+from .utils.log import logger as utils_logger
 
 
 class Chat:
@@ -26,45 +27,45 @@ def __init__(self, logger=logging.getLogger(__name__)):
         self.normalizer = {}
         self.homophones_replacer = None
         self.logger = logger
-        set_utils_logger(logger)
+        utils_logger.set_logger(logger)
 
-    def check_model(self, level = logging.INFO, use_decoder = False):
+    def has_loaded(self, use_decoder = False):
         not_finish = False
-        check_list = ['vocos', 'gpt', 'tokenizer']
+        check_list = ['gpt', 'tokenizer']
 
         if use_decoder:
             check_list.append('decoder')
         else:
             check_list.append('dvae')
-
+
         for module in check_list:
             if module not in self.pretrain_models:
-                self.logger.log(logging.WARNING, f'{module} not initialized.')
+                self.logger.warn(f'{module} not initialized.')
                 not_finish = True
-
+
+        if not hasattr(self, "_vocos_decode") or not hasattr(self, "vocos"):
+            self.logger.warn('vocos not initialized.')
+            not_finish = True
+
         if not not_finish:
-            self.logger.log(level, f'All initialized.')
+            self.logger.info('all models has been initialized.')
 
         return not not_finish
 
-    def load_models(
+    def download_models(
         self,
         source: Literal['huggingface', 'local', 'custom']='local',
         force_redownload=False,
-        compile: bool = True,
         custom_path: Optional[torch.serialization.FILE_LIKE]=None,
-        device: Optional[torch.device] = None,
-        coef: Optional[torch.Tensor] = None,
-    ):
+    ) -> Optional[str]:
         if source == 'local':
-            torch.load
             download_path = os.getcwd()
             if not check_all_assets(update=True) or force_redownload:
                 with tempfile.TemporaryDirectory() as tmp:
                     download_all_assets(tmpdir=tmp)
                 if not check_all_assets(update=False):
-                    self.logger.error("counld not satisfy all assets needed.")
-                    return False
+                    self.logger.error("download to local path %s failed.", download_path)
+                    return None
         elif source == 'huggingface':
             hf_home = os.getenv('HF_HOME', os.path.expanduser("~/.cache/huggingface"))
             try:
@@ -73,18 +74,38 @@ def load_models(
                 download_path = None
             if download_path is None or force_redownload:
                 self.logger.log(logging.INFO, f'Download from HF: https://huggingface.co/2Noise/ChatTTS')
-                download_path = snapshot_download(repo_id="2Noise/ChatTTS", allow_patterns=["*.pt", "*.yaml"])
+                try:
+                    download_path = snapshot_download(repo_id="2Noise/ChatTTS", allow_patterns=["*.pt", "*.yaml"])
+                except:
+                    download_path = None
             else:
-                self.logger.log(logging.INFO, f'Load from cache: {download_path}')
+                self.logger.log(logging.INFO, f'load latest snapshot from cache: {download_path}')
+            if download_path is None:
+                self.logger.error("download from huggingface failed.")
+                return None
         elif source == 'custom':
-            self.logger.log(logging.INFO, f'Load from local: {custom_path}')
+            self.logger.log(logging.INFO, f'try to load from local: {custom_path}')
             download_path = custom_path
+
+        return download_path
 
+    def load_models(
+        self,
+        source: Literal['huggingface', 'local', 'custom']='local',
+        force_redownload=False,
+        compile: bool = True,
+        custom_path: Optional[torch.serialization.FILE_LIKE]=None,
+        device: Optional[torch.device] = None,
+        coef: Optional[torch.Tensor] = None,
+    ) -> bool:
+        download_path = self.download_models(source, force_redownload, custom_path)
+        if download_path is None:
+            return False
         return self._load(
             device=device, compile=compile, coef=coef,
             **{k: os.path.join(download_path, v) for k, v in OmegaConf.load(os.path.join(download_path, 'config', 'path.yaml')).items()},
         )
-
+
     def _load(
         self,
         vocos_config_path: str = None,
@@ -112,9 +133,17 @@ def _load(
         ).eval()
         assert vocos_ckpt_path, 'vocos_ckpt_path should not be None'
         vocos.load_state_dict(torch.load(vocos_ckpt_path))
-        self.pretrain_models['vocos'] = vocos
+        self.vocos = vocos
+        if "mps" in str(self.device):
+            self._vocos_decode: Callable[[torch.Tensor], np.ndarray] = lambda spec: self.vocos.decode(
+                spec.cpu()
+            ).cpu().numpy()
+        else:
+            self._vocos_decode: Callable[[torch.Tensor], np.ndarray] = lambda spec: self.vocos.decode(
+                spec
+            ).cpu().numpy()
         self.logger.log(logging.INFO, 'vocos loaded.')
-
+
         if dvae_config_path:
             cfg = OmegaConf.load(dvae_config_path)
             dvae = DVAE(**cfg, coef=coef).to(device).eval()
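Note on the hunk above: the Vocos model moves out of pretrain_models onto self.vocos, and the decode path is chosen once at load time instead of being rebuilt per call; on MPS the spectrogram is handed to Vocos from a CPU copy, elsewhere it is decoded in place and only the result is moved. A minimal standalone sketch of the same branching, assuming an already-loaded Vocos instance (the helper name make_vocos_decode is illustrative, not part of this commit):

import torch
from vocos import Vocos

def make_vocos_decode(vocos: Vocos, device: torch.device):
    # Mirror of the diff's _vocos_decode lambdas: on MPS, decode from a CPU
    # copy of the spectrogram; otherwise decode on the current device and
    # only move the finished waveform back to CPU as a NumPy array.
    if "mps" in str(device):
        return lambda spec: vocos.decode(spec.cpu()).cpu().numpy()
    return lambda spec: vocos.decode(spec).cpu().numpy()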
@@ -157,8 +186,13 @@ def _load(
 
         self.coef = coef
 
-        return self.check_model()
+        return self.has_loaded()
 
+    def unload(self):
+        logger = self.logger
+        del_all(self)
+        self.__init__(logger)
+
     def _infer(
         self,
         text,
@@ -173,23 +207,23 @@ def _infer(
         do_homophone_replacement=True
     ):
 
-        assert self.check_model(use_decoder=use_decoder)
+        assert self.has_loaded(use_decoder=use_decoder)
 
         if not isinstance(text, list):
             text = [text]
         if do_text_normalization:
             for i, t in enumerate(text):
                 _lang = detect_language(t) if lang is None else lang
-                if self.init_normalizer(_lang):
+                if self._init_normalizer(_lang):
                     text[i] = self.normalizer[_lang](t)
                     if _lang == 'zh':
                         text[i] = apply_half2full_map(text[i])
         for i, t in enumerate(text):
             invalid_characters = count_invalid_characters(t)
             if len(invalid_characters):
-                self.logger.log(logging.WARNING, f'Invalid characters found! : {invalid_characters}')
+                self.logger.warn(f'Invalid characters found! : {invalid_characters}')
                 text[i] = apply_character_map(t)
-            if do_homophone_replacement and self.init_homophones_replacer():
+            if do_homophone_replacement and self._init_homophones_replacer():
                 text[i], replaced_words = self.homophones_replacer.replace(text[i])
                 if replaced_words:
                     repl_res = ', '.join([f'{_[0]}->{_[1]}' for _ in replaced_words])
@@ -205,64 +239,25 @@ def _infer(
             text_tokens = refined.ids
             text_tokens = [i[i < self.pretrain_models['tokenizer'].convert_tokens_to_ids('[break_0]')] for i in text_tokens]
             text = self.pretrain_models['tokenizer'].batch_decode(text_tokens)
-            del_all(refined)
+            refined.destroy()
             if refine_text_only:
                 yield text
                 return
 
         text = [params_infer_code.get('prompt', '') + i for i in text]
         params_infer_code.pop('prompt', '')
-        result_gen = infer_code(
+
+        length = [0 for _ in range(len(text))]
+        for result in infer_code(
             self.pretrain_models,
             text,
             device=self.device,
             **params_infer_code,
             return_hidden=use_decoder,
             stream=stream,
-        )
-        if use_decoder:
-            docoder_name = 'decoder'
-        else:
-            docoder_name = 'dvae'
-        if "mps" in str(self.device):
-            vocos_decode = lambda spec: [self.pretrain_models['vocos'].decode(
-                i.cpu()
-            ).cpu().numpy() for i in spec]
-        else:
-            vocos_decode = lambda spec: [self.pretrain_models['vocos'].decode(
-                i
-            ).cpu().numpy() for i in spec]
-        if stream:
-
-            length = 0
-            for result in result_gen:
-                x = result.hiddens if use_decoder else result.ids
-                assert len(x) == 1
-                chunk_data = x[0]
-                start_seek = length
-                length = len(chunk_data)
-                self.logger.debug(f'{start_seek=} total len: {length}, new len: {length - start_seek = }')
-                chunk_data = chunk_data[start_seek:]
-                if not len(chunk_data):
-                    continue
-                self.logger.debug(f'new hidden {len(chunk_data)=}')
-                mel_spec = [self.pretrain_models[docoder_name](i[None].permute(0,2,1).to(self.device)) for i in [chunk_data]]
-                del_all(result)
-                del chunk_data
-                del_all(x)
-                wav = vocos_decode(mel_spec)
-                del_all(mel_spec)
-                self.logger.debug(f'yield wav chunk {len(wav[0])=} {len(wav[0][0])=}')
-                yield wav
-            return
-        result = next(result_gen)
-        x = result.hiddens if use_decoder else result.ids
-        mel_spec = [self.pretrain_models[docoder_name](i[None].permute(0,2,1).to(self.device)) for i in x]
-        del_all(result)
-        del_all(x)
-        wav = vocos_decode(mel_spec)
-        del_all(mel_spec)
-        yield wav
+        ):
+            wav = self.decode_to_wavs(result, length, use_decoder)
+            yield wav
 
     def infer(
         self,
@@ -294,13 +289,35 @@ def infer(
         else:
             return next(res_gen)
 
-    def sample_random_speaker(self, ):
-
+    def sample_random_speaker(self):
         dim = self.pretrain_models['gpt'].gpt.layers[0].mlp.gate_proj.in_features
         std, mean = self.pretrain_models['spk_stat'].chunk(2)
         return torch.randn(dim, device=std.device) * std + mean
-
-    def init_normalizer(self, lang) -> bool:
+
+    def decode_to_wavs(self, result: GPT.GenerationOutputs, start_seeks: List[int], use_decoder: bool):
+        x = result.hiddens if use_decoder else result.ids
+        wavs: List[np.ndarray] = []
+        for i, chunk_data in enumerate(x):
+            start_seek = start_seeks[i]
+            length = len(chunk_data)
+            if length <= start_seek:
+                wavs.append(None)
+                continue
+            start_seeks[i] = length
+            chunk_data = chunk_data[start_seek:]
+            if use_decoder:
+                decoder = self.pretrain_models['decoder']
+            else:
+                decoder = self.pretrain_models['dvae']
+            mel_spec = decoder(chunk_data[None].permute(0,2,1).to(self.device))
+            del chunk_data
+            wavs.append(self._vocos_decode(mel_spec))
+            del_all(mel_spec)
+        result.destroy()
+        del_all(x)
+        return wavs
+
+    def _init_normalizer(self, lang) -> bool:
 
         if lang in self.normalizer:
             return True
@@ -335,16 +352,16 @@ def init_normalizer(self, lang) -> bool:
             )
             return False
 
-    def init_homophones_replacer(self):
+    def _init_homophones_replacer(self):
         if self.homophones_replacer:
             return True
         else:
             try:
                 self.homophones_replacer = HomophonesReplacer(os.path.join(os.path.dirname(__file__), 'res', 'homophones_map.json'))
-                self.logger.log(logging.INFO, 'homophones_replacer loaded.')
+                self.logger.log(logging.INFO, 'successfully loaded HomophonesReplacer.')
                 return True
             except (IOError, json.JSONDecodeError) as e:
-                self.logger.log(logging.WARNING, f'Error loading homophones map: {e}')
+                self.logger.log(logging.WARNING, f'error loading homophones map: {e}')
             except Exception as e:
-                self.logger.log(logging.WARNING, f'Error loading homophones_replacer: {e}')
+                self.logger.log(logging.WARNING, f'error loading HomophonesReplacer: {e}')
         return False
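
Taken together, the core.py changes split asset downloading out of load_models (download_models returns the asset directory or None), rename check_model to has_loaded, add unload, and route synthesis through decode_to_wavs, whose per-text output may be None when a step produced no new samples. A rough usage sketch of the refactored surface, assuming a standard ChatTTS install; the chunk buffering and skip-None handling are illustrative, not part of this commit:

import numpy as np
import ChatTTS

chat = ChatTTS.Chat()

# load_models() now returns a bool; internally it calls download_models(),
# which returns the asset directory or None on failure.
if not chat.load_models(source='local', compile=False):
    raise RuntimeError('ChatTTS assets could not be loaded')
assert chat.has_loaded()

# With stream=True, each yielded item is a list with one entry per input
# text; an entry is None when decode_to_wavs had no new samples for it yet.
chunks = []
for wav_chunks in chat.infer(['chat T T S is a text to speech model.'], stream=True):
    wav = wav_chunks[0]
    if wav is None:
        continue
    chunks.append(wav)

full_wav = np.concatenate(chunks, axis=None)  # flatten and join the chunks

chat.unload()  # drop loaded models and reset the Chat instance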

ChatTTS/model/dvae.py

Lines changed: 14 additions & 13 deletions
@@ -164,20 +164,21 @@ def __repr__(self) -> str:
         return b14.encode_to_string(self.coef.cpu().numpy().astype(np.float32).tobytes())
 
     def forward(self, inp: torch.Tensor) -> torch.Tensor:
+        with torch.no_grad():
 
-        if self.vq_layer is not None:
-            vq_feats = self.vq_layer._embed(inp)
-        else:
-            vq_feats = inp.detach().clone()
+            if self.vq_layer is not None:
+                vq_feats = self.vq_layer._embed(inp)
+            else:
+                vq_feats = inp.detach().clone()
 
-        vq_feats = vq_feats.view(
-            (vq_feats.size(0), 2, vq_feats.size(1)//2, vq_feats.size(2)),
-        ).permute(0, 2, 3, 1).flatten(2)
+            vq_feats = vq_feats.view(
+                (vq_feats.size(0), 2, vq_feats.size(1)//2, vq_feats.size(2)),
+            ).permute(0, 2, 3, 1).flatten(2)
 
-        dec_out = self.out_conv(
-            self.decoder(
-                input=vq_feats.transpose_(1, 2),
-            ).transpose_(1, 2),
-        )
+            dec_out = self.out_conv(
+                self.decoder(
+                    input=vq_feats.transpose_(1, 2),
+                ).transpose_(1, 2),
+            )
 
-        return torch.mul(dec_out, self.coef, out=dec_out)
+            return torch.mul(dec_out, self.coef, out=dec_out)
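
The dvae.py change only wraps the existing forward body in torch.no_grad(), so the decode runs without recording an autograd graph. A tiny illustration of what that buys at inference time (the toy layer below is an assumption, not ChatTTS code):

import torch

layer = torch.nn.Linear(4, 4)
x = torch.randn(1, 4)

y = layer(x)
print(y.requires_grad)   # True: autograd tracks the op and keeps intermediate buffers

with torch.no_grad():
    y = layer(x)
print(y.requires_grad)   # False: no graph is recorded, which saves memory during inference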
