Commit d57dde1

1 parent: c42852e

File tree: 7 files changed, +204 -101 lines

ChatTTS/core.py

Lines changed: 36 additions & 14 deletions
@@ -5,6 +5,7 @@
 from functools import partial
 from typing import Literal
 import tempfile
+from typing import Optional
 
 import torch
 from omegaconf import OmegaConf
@@ -15,7 +16,7 @@
 from .model.gpt import GPT_warpper
 from .utils.gpu_utils import select_device
 from .utils.infer_utils import count_invalid_characters, detect_language, apply_character_map, apply_half2full_map, HomophonesReplacer
-from .utils.io_utils import get_latest_modified_file
+from .utils.io import get_latest_modified_file, del_all
 from .infer.api import refine_text, infer_code
 from .utils.download import check_all_assets, download_all_assets
@@ -91,17 +92,18 @@ def _load(
         decoder_config_path: str = None,
         decoder_ckpt_path: str = None,
         tokenizer_path: str = None,
-        device: str = None,
+        device: Optional[torch.device] = None,
         compile: bool = True,
     ):
-        if not device:
-            device = select_device(4095)
+        if device is None:
+            device = select_device(4096)
         self.logger.log(logging.INFO, f'use {device}')
-
+        self.device = device
+
         if vocos_config_path:
             vocos = Vocos.from_hparams(vocos_config_path).to(
                 # vocos on mps will crash, use cpu fallback
-                "cpu" if torch.backends.mps.is_available() else device
+                "cpu" if "mps" in str(device) else device
             ).eval()
             assert vocos_ckpt_path, 'vocos_ckpt_path should not be None'
             vocos.load_state_dict(torch.load(vocos_ckpt_path))
@@ -118,7 +120,7 @@ def _load(
 
         if gpt_config_path:
             cfg = OmegaConf.load(gpt_config_path)
-            gpt = GPT_warpper(**cfg).to(device).eval()
+            gpt = GPT_warpper(**cfg, device=device).eval()
             assert gpt_ckpt_path, 'gpt_ckpt_path should not be None'
             gpt.load_state_dict(torch.load(gpt_ckpt_path))
             if compile and 'cuda' in str(device):
@@ -188,6 +190,7 @@ def _infer(
         text_tokens = refine_text(
             self.pretrain_models,
             text,
+            device=self.device,
             **params_refine_text,
         )['ids']
         text_tokens = [i[i < self.pretrain_models['tokenizer'].convert_tokens_to_ids('[break_0]')] for i in text_tokens]
@@ -198,16 +201,28 @@ def _infer(
 
         text = [params_infer_code.get('prompt', '') + i for i in text]
         params_infer_code.pop('prompt', '')
-        result_gen = infer_code(self.pretrain_models, text, **params_infer_code, return_hidden=use_decoder, stream=stream)
+        result_gen = infer_code(
+            self.pretrain_models,
+            text,
+            device=self.device,
+            **params_infer_code,
+            return_hidden=use_decoder,
+            stream=stream,
+        )
         if use_decoder:
             field = 'hiddens'
             docoder_name = 'decoder'
         else:
             field = 'ids'
             docoder_name = 'dvae'
-        vocos_decode = lambda spec: [self.pretrain_models['vocos'].decode(
-            i.cpu() if torch.backends.mps.is_available() else i
-        ).cpu().numpy() for i in spec]
+        if "mps" in str(self.device):
+            vocos_decode = lambda spec: [self.pretrain_models['vocos'].decode(
+                i.cpu()
+            ).cpu().numpy() for i in spec]
+        else:
+            vocos_decode = lambda spec: [self.pretrain_models['vocos'].decode(
+                i
+            ).cpu().numpy() for i in spec]
         if stream:
 
             length = 0
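A note on the vocos guard above: `torch.backends.mps.is_available()` is a machine-level probe, so on Apple-Silicon hosts it forced the CPU fallback even when the pipeline was actually running on CPU or CUDA; the new check keys off the target device itself. A minimal illustration (the values shown assume an M-series Mac):

```python
import torch

# On an Apple-Silicon machine the backend probe is True no matter
# which device the pipeline actually uses:
device = torch.device("cpu")
old_guard = torch.backends.mps.is_available()  # True on any M-series Mac
new_guard = "mps" in str(device)               # False here; True only for mps devices
```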
@@ -221,13 +236,20 @@ def _infer(
                 if not len(chunk_data):
                     continue
                 self.logger.debug(f'new hidden {len(chunk_data)=}')
-                mel_spec = [self.pretrain_models[docoder_name](i[None].permute(0,2,1)) for i in [chunk_data]]
+                mel_spec = [self.pretrain_models[docoder_name](i[None].permute(0,2,1).to(self.device)) for i in [chunk_data]]
+                del_all(result)
+                del chunk_data
                 wav = vocos_decode(mel_spec)
+                del_all(mel_spec)
                 self.logger.debug(f'yield wav chunk {len(wav[0])=} {len(wav[0][0])=}')
                 yield wav
             return
-        mel_spec = [self.pretrain_models[docoder_name](i[None].permute(0,2,1)) for i in next(result_gen)[field]]
-        yield vocos_decode(mel_spec)
+        result = next(result_gen)
+        mel_spec = [self.pretrain_models[docoder_name](i[None].permute(0,2,1).to(self.device)) for i in result[field]]
+        del_all(result)
+        wav = vocos_decode(mel_spec)
+        del_all(mel_spec)
+        yield wav
 
     def infer(
         self,
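The `del_all` helper imported from `.utils.io` and called above to release intermediate tensors is not shown in this diff. A minimal sketch of what such a helper plausibly does (hypothetical, for illustration only; the real implementation lives in ChatTTS/utils/io.py):

```python
from typing import Any

def del_all(container: Any) -> None:
    """Pop every entry out of a dict or list, recursing into nested
    containers, so the tensors they reference can be freed promptly."""
    if isinstance(container, dict):
        for key in list(container.keys()):
            item = container.pop(key)
            del_all(item)
            del item
    elif isinstance(container, list):
        while container:
            item = container.pop()
            del_all(item)
            del item
```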

ChatTTS/infer/api.py

Lines changed: 54 additions & 34 deletions
@@ -2,7 +2,10 @@
 import torch
 import torch.nn.functional as F
 from transformers.generation import TopKLogitsWarper, TopPLogitsWarper
+
 from ..utils.infer_utils import CustomRepetitionPenaltyLogitsProcessorRepeat
+from ..utils.io import del_all
+from ..model.gpt import GPT_warpper
 
 def infer_code(
     models,
@@ -14,39 +17,42 @@ def infer_code(
     repetition_penalty = 1.05,
     max_new_token = 2048,
     stream=False,
+    device="cpu",
     **kwargs
 ):
-
-    device = next(models['gpt'].parameters()).device
-
+
+    gpt: GPT_warpper = models['gpt']
+
     if not isinstance(text, list):
         text = [text]
 
     if not isinstance(temperature, list):
-        temperature = [temperature] * models['gpt'].num_vq
+        temperature = [temperature] * gpt.num_vq
 
     if spk_emb is not None:
         text = [f'[Stts][spk_emb]{i}[Ptts]' for i in text]
     else:
         text = [f'[Stts][empty_spk]{i}[Ptts]' for i in text]
 
-    text_token = models['tokenizer'](text, return_tensors='pt', add_special_tokens=False, padding=True).to(device)
-    input_ids = text_token['input_ids'][...,None].expand(-1, -1, models['gpt'].num_vq)
-    text_mask = torch.ones(text_token['input_ids'].shape, dtype=bool, device=device)
-
-    inputs = {
-        'input_ids': input_ids,
-        'text_mask': text_mask,
-        'attention_mask': text_token['attention_mask'],
-    }
+    text_token_tmp = models['tokenizer'](text, return_tensors='pt', add_special_tokens=False, padding=True)
+    text_token = text_token_tmp.to(device)
+    del text_token_tmp
+    input_ids = text_token['input_ids'][...,None].expand(-1, -1, gpt.num_vq).to(gpt.device_gpt)
+    text_mask = torch.ones(text_token['input_ids'].shape, dtype=bool, device=gpt.device_gpt)
+
+    emb = gpt.get_emb(
+        input_ids=input_ids,
+        text_mask=text_mask,
+    )
+    del text_mask
 
-    emb = models['gpt'].get_emb(**inputs)
     if spk_emb is not None:
-        emb[inputs['input_ids'][..., 0] == models['tokenizer'].convert_tokens_to_ids('[spk_emb]')] = \
-            F.normalize(spk_emb.to(device).to(emb.dtype)[None].expand(len(text), -1), p=2.0, dim=1, eps=1e-12)
-
-    num_code = models['gpt'].emb_code[0].num_embeddings - 1
-
+        n = F.normalize(spk_emb.to(emb.dtype)[None].expand(len(text), -1), p=2.0, dim=1, eps=1e-12).to(gpt.device_gpt)
+        emb[input_ids[..., 0] == models['tokenizer'].convert_tokens_to_ids('[spk_emb]')] = n
+        del n
+
+    num_code = int(gpt.emb_code[0].num_embeddings - 1)
+
     LogitsWarpers = []
     if top_P is not None:
         LogitsWarpers.append(TopPLogitsWarper(top_P, min_tokens_to_keep=3))
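For context on the warper list being assembled here: the `transformers` logits warpers are callables that take `(input_ids, scores)` and return filtered scores. An illustrative standalone use (the tensor shapes are assumptions, not values from this code):

```python
import torch
from transformers.generation import TopPLogitsWarper

warper = TopPLogitsWarper(top_p=0.7, min_tokens_to_keep=3)
scores = torch.randn(1, 626)                       # vocab size assumed for illustration
input_ids = torch.zeros((1, 8), dtype=torch.long)  # prior tokens (unused by top-p)
filtered = warper(input_ids, scores)               # logits outside the top-p nucleus -> -inf
```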
@@ -58,10 +64,10 @@ def infer_code(
         LogitsProcessors.append(CustomRepetitionPenaltyLogitsProcessorRepeat(\
             repetition_penalty, num_code, 16))
 
-    result = models['gpt'].generate(
-        emb, inputs['input_ids'],
+    result = gpt.generate(
+        emb, input_ids,
         temperature = torch.tensor(temperature, device=device),
-        attention_mask = inputs['attention_mask'],
+        attention_mask = text_token['attention_mask'],
         LogitsWarpers = LogitsWarpers,
         LogitsProcessors = LogitsProcessors,
         eos_token = num_code,
@@ -71,6 +77,11 @@ def infer_code(
         **kwargs
     )
 
+    del_all(text_token)
+    del emb, text_token, input_ids
+    del_all(LogitsWarpers)
+    del_all(LogitsProcessors)
+
     return result

@@ -83,11 +94,12 @@ def refine_text(
     repetition_penalty = 1.0,
     max_new_token = 384,
     prompt = '',
+    device="cpu",
     **kwargs
 ):
-
-    device = next(models['gpt'].parameters()).device
-
+
+    gpt: GPT_warpper = models['gpt']
+
     if not isinstance(text, list):
         text = [text]
 
@@ -97,11 +109,7 @@ def refine_text(
     text_token = models['tokenizer'](text, return_tensors='pt', add_special_tokens=False, padding=True).to(device)
     text_mask = torch.ones(text_token['input_ids'].shape, dtype=bool, device=device)
 
-    inputs = {
-        'input_ids': text_token['input_ids'][...,None].expand(-1, -1, models['gpt'].num_vq),
-        'text_mask': text_mask,
-        'attention_mask': text_token['attention_mask'],
-    }
+    input_ids = text_token['input_ids'][...,None].expand(-1, -1, gpt.num_vq)
 
     LogitsWarpers = []
     if top_P is not None:
@@ -112,11 +120,17 @@ def refine_text(
     LogitsProcessors = []
     if repetition_penalty is not None and repetition_penalty != 1:
         LogitsProcessors.append(CustomRepetitionPenaltyLogitsProcessorRepeat(repetition_penalty, len(models['tokenizer']), 16))
-
-    result = models['gpt'].generate(
-        models['gpt'].get_emb(**inputs), inputs['input_ids'],
+
+    emb = gpt.get_emb(
+        input_ids=input_ids,
+        text_mask=text_mask,
+    )
+    del text_mask
+
+    result = gpt.generate(
+        emb, input_ids,
         temperature = torch.tensor([temperature,], device=device),
-        attention_mask = inputs['attention_mask'],
+        attention_mask = text_token['attention_mask'],
         LogitsWarpers = LogitsWarpers,
         LogitsProcessors = LogitsProcessors,
         eos_token = torch.tensor(models['tokenizer'].convert_tokens_to_ids('[Ebreak]'), device=device)[None],
@@ -125,4 +139,10 @@ def refine_text(
         stream = False,
         **kwargs
     )
+
+    del_all(text_token)
+    del emb, text_token, input_ids
+    del_all(LogitsWarpers)
+    del_all(LogitsProcessors)
+
     return next(result)
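With these changes, both entry points take the target device from the caller instead of probing `next(models['gpt'].parameters()).device`, which matters now that the GPT wrapper can keep submodules on its own `device_gpt`. A hedged usage sketch (the `models` dict and input text are placeholders, assumed to come from `Chat._load`):

```python
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# models: the pretrain_models dict assembled by Chat._load (assumed here)
refined = refine_text(models, "A sentence to refine.", device=device)
ids = refined["ids"]

result_gen = infer_code(models, "A sentence to speak.", device=device, stream=False)
result = next(result_gen)  # generate() yields results, as core.py's usage shows
```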
