Skip to content

Commit 2ba7c6c

Browse files
authored
feat(dvae): expose coef for customization (#405)
and unify the coef of dvae & decoder
1 parent b4c3cff commit 2ba7c6c

File tree

4 files changed

+36
-15
lines changed

4 files changed

+36
-15
lines changed

ChatTTS/core.py

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -51,12 +51,15 @@ def load_models(
5151
self,
5252
source: Literal['huggingface', 'local', 'custom']='local',
5353
force_redownload=False,
54-
custom_path='<LOCAL_PATH>',
55-
**kwargs,
54+
compile: bool = True,
55+
custom_path: Optional[torch.serialization.FILE_LIKE]=None,
56+
device: Optional[torch.device] = None,
57+
coef: Optional[torch.Tensor] = None,
5658
):
5759
if source == 'local':
60+
torch.load
5861
download_path = os.getcwd()
59-
if not check_all_assets(update=True):
62+
if not check_all_assets(update=True) or force_redownload:
6063
with tempfile.TemporaryDirectory() as tmp:
6164
download_all_assets(tmpdir=tmp)
6265
if not check_all_assets(update=False):
@@ -77,7 +80,10 @@ def load_models(
7780
self.logger.log(logging.INFO, f'Load from local: {custom_path}')
7881
download_path = custom_path
7982

80-
return self._load(**{k: os.path.join(download_path, v) for k, v in OmegaConf.load(os.path.join(download_path, 'config', 'path.yaml')).items()}, **kwargs)
83+
return self._load(
84+
device=device, compile=compile, coef=coef,
85+
**{k: os.path.join(download_path, v) for k, v in OmegaConf.load(os.path.join(download_path, 'config', 'path.yaml')).items()},
86+
)
8187

8288
def _load(
8389
self,
@@ -92,6 +98,7 @@ def _load(
9298
tokenizer_path: str = None,
9399
device: Optional[torch.device] = None,
94100
compile: bool = True,
101+
coef: Optional[str] = None
95102
):
96103
if device is None:
97104
device = select_device(4096)
@@ -110,7 +117,8 @@ def _load(
110117

111118
if dvae_config_path:
112119
cfg = OmegaConf.load(dvae_config_path)
113-
dvae = DVAE(**cfg).to(device).eval()
120+
dvae = DVAE(**cfg, coef=coef).to(device).eval()
121+
coef = str(dvae)
114122
assert dvae_ckpt_path, 'dvae_ckpt_path should not be None'
115123
dvae.load_state_dict(torch.load(dvae_ckpt_path))
116124
self.pretrain_models['dvae'] = dvae
@@ -134,7 +142,8 @@ def _load(
134142

135143
if decoder_config_path:
136144
cfg = OmegaConf.load(decoder_config_path)
137-
decoder = DVAE(**cfg).to(device).eval()
145+
decoder = DVAE(**cfg, coef=coef).to(device).eval()
146+
coef = str(decoder)
138147
assert decoder_ckpt_path, 'decoder_ckpt_path should not be None'
139148
decoder.load_state_dict(torch.load(decoder_ckpt_path, map_location='cpu'))
140149
self.pretrain_models['decoder'] = decoder
@@ -145,7 +154,9 @@ def _load(
145154
tokenizer.padding_side = 'left'
146155
self.pretrain_models['tokenizer'] = tokenizer
147156
self.logger.log(logging.INFO, 'tokenizer loaded.')
148-
157+
158+
self.coef = coef
159+
149160
return self.check_model()
150161

151162
def _infer(

ChatTTS/model/dvae.py

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
import math
2-
from typing import List
2+
from typing import List, Optional
33

4+
import pybase16384 as b14
5+
import numpy as np
46
import torch
57
import torch.nn as nn
68
import torch.nn.functional as F
@@ -74,7 +76,7 @@ def __init__(self,
7476

7577
def _embed(self, x: torch.Tensor):
7678
if self.transpose:
77-
x.transpose_(1, 2)
79+
x = x.transpose(1, 2)
7880
"""
7981
x = rearrange(
8082
x, "b t (g r) -> g b t r", g = self.G, r = self.R,
@@ -84,9 +86,9 @@ def _embed(self, x: torch.Tensor):
8486
feat = self.quantizer.get_output_from_indices(x)
8587
return feat.transpose_(1,2) if self.transpose else feat
8688

87-
def forward(self, x,):
89+
def forward(self, x):
8890
if self.transpose:
89-
x.transpose_(1,2)
91+
x = x.transpose(1, 2)
9092
feat, ind = self.quantizer(x)
9193
"""
9294
ind = rearrange(
@@ -127,7 +129,7 @@ def __init__(self, idim: int, odim: int,
127129
for _ in range(n_layer)])
128130
self.conv_out = nn.Conv1d(hidden, odim, kernel_size=1, bias=False)
129131

130-
def forward(self, input, conditioning=None):
132+
def forward(self, input: torch.Tensor, conditioning=None) -> torch.Tensor:
131133
# B, T, C
132134
x = input.transpose_(1, 2)
133135
y = self.conv_in(x)
@@ -142,17 +144,24 @@ def forward(self, input, conditioning=None):
142144

143145
class DVAE(nn.Module):
144146
def __init__(
145-
self, decoder_config, vq_config, dim=512
147+
self, decoder_config, vq_config, dim=512, coef: Optional[str] = None,
146148
):
147149
super().__init__()
148-
self.register_buffer('coef', torch.randn(1, 100, 1))
150+
if coef is None:
151+
coef = torch.rand(100)
152+
else:
153+
coef = torch.from_numpy(np.frombuffer(b14.decode_from_string(coef), dtype=np.float32))
154+
self.register_buffer('coef', coef.unsqueeze(0).unsqueeze_(2))
149155

150156
self.decoder = DVAEDecoder(**decoder_config)
151157
self.out_conv = nn.Conv1d(dim, 100, 3, 1, 1, bias=False)
152158
if vq_config is not None:
153159
self.vq_layer = GFSQ(**vq_config)
154160
else:
155161
self.vq_layer = None
162+
163+
def __repr__(self) -> str:
164+
return b14.encode_to_string(self.coef.cpu().numpy().astype(np.float32).tobytes())
156165

157166
def forward(self, inp: torch.Tensor) -> torch.Tensor:
158167

requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ vocos
99
IPython
1010
gradio
1111
python-dotenv
12+
pybase16384
1213
pynini==2.1.5; sys_platform == 'linux'
1314
WeTextProcessing; sys_platform == 'linux'
1415
nemo_text_processing; sys_platform == 'linux'

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,12 @@
66
install_requires=['omegaconf>=2.3.0',
77
'numpy<2.0.0',
88
'numba',
9+
'pybase16384',
910
'torch>=2.1.0',
1011
'tqdm',
1112
'vector_quantize_pytorch',
1213
'transformers>=4.41.1',
1314
'vocos',
14-
'IPython',
1515
], # 定义依赖哪些模块
1616
packages=find_packages(), # 系统自动从当前目录开始找包
1717
)

0 commit comments

Comments
 (0)