Skip to content

Commit e0a9e7e

Browse files
authored
fix(chat): model unload memory release (#418)
1 parent 21c8ecc commit e0a9e7e

File tree

2 files changed

+12
-7
lines changed

2 files changed

+12
-7
lines changed

ChatTTS/core.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,7 @@ def _load(
129129
"cpu" if "mps" in str(device) else device
130130
).eval()
131131
assert vocos_ckpt_path, 'vocos_ckpt_path should not be None'
132-
vocos.load_state_dict(torch.load(vocos_ckpt_path))
132+
vocos.load_state_dict(torch.load(vocos_ckpt_path, weights_only=True, mmap=True))
133133
self.vocos = vocos
134134
if "mps" in str(self.device):
135135
self._vocos_decode: Callable[[torch.Tensor], np.ndarray] = lambda spec: self.vocos.decode(
@@ -146,15 +146,15 @@ def _load(
146146
dvae = DVAE(**cfg, coef=coef).to(device).eval()
147147
coef = str(dvae)
148148
assert dvae_ckpt_path, 'dvae_ckpt_path should not be None'
149-
dvae.load_state_dict(torch.load(dvae_ckpt_path, map_location=device))
149+
dvae.load_state_dict(torch.load(dvae_ckpt_path, weights_only=True, mmap=True))
150150
self.dvae = dvae
151151
self.logger.log(logging.INFO, 'dvae loaded.')
152152

153153
if gpt_config_path:
154154
cfg = OmegaConf.load(gpt_config_path)
155155
gpt = GPT(**cfg, device=device, logger=self.logger).eval()
156156
assert gpt_ckpt_path, 'gpt_ckpt_path should not be None'
157-
gpt.load_state_dict(torch.load(gpt_ckpt_path, map_location=device))
157+
gpt.load_state_dict(torch.load(gpt_ckpt_path, weights_only=True, mmap=True))
158158
if compile and 'cuda' in str(device):
159159
try:
160160
gpt.gpt.forward = torch.compile(gpt.gpt.forward, backend='inductor', dynamic=True)
@@ -163,20 +163,20 @@ def _load(
163163
self.gpt = gpt
164164
spk_stat_path = os.path.join(os.path.dirname(gpt_ckpt_path), 'spk_stat.pt')
165165
assert os.path.exists(spk_stat_path), f'Missing spk_stat.pt: {spk_stat_path}'
166-
self.pretrain_models['spk_stat'] = torch.load(spk_stat_path, map_location=device).to(device)
166+
self.pretrain_models['spk_stat'] = torch.load(spk_stat_path, weights_only=True, mmap=True).to(device)
167167
self.logger.log(logging.INFO, 'gpt loaded.')
168168

169169
if decoder_config_path:
170170
cfg = OmegaConf.load(decoder_config_path)
171171
decoder = DVAE(**cfg, coef=coef).to(device).eval()
172172
coef = str(decoder)
173173
assert decoder_ckpt_path, 'decoder_ckpt_path should not be None'
174-
decoder.load_state_dict(torch.load(decoder_ckpt_path, map_location=device))
174+
decoder.load_state_dict(torch.load(decoder_ckpt_path, weights_only=True, mmap=True))
175175
self.decoder = decoder
176176
self.logger.log(logging.INFO, 'decoder loaded.')
177177

178178
if tokenizer_path:
179-
tokenizer = torch.load(tokenizer_path, map_location=device)
179+
tokenizer = torch.load(tokenizer_path, map_location=device, mmap=True)
180180
tokenizer.padding_side = 'left'
181181
self.pretrain_models['tokenizer'] = tokenizer
182182
self.logger.log(logging.INFO, 'tokenizer loaded.')
@@ -187,7 +187,11 @@ def _load(
187187

188188
def unload(self):
189189
logger = self.logger
190-
del_all(self)
190+
del_all(self.pretrain_models)
191+
del_list = ["vocos", "_vocos_decode", 'gpt', 'decoder', 'dvae']
192+
for module in del_list:
193+
if hasattr(self, module):
194+
delattr(self, module)
191195
self.__init__(logger)
192196

193197
def _infer(

examples/web/funcs.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import sys
12
import random
23
from typing import Optional
34

0 commit comments

Comments
 (0)